continual-foragax 0.39.0__py3-none-any.whl → 0.41.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.39.0.dist-info → continual_foragax-0.41.0.dist-info}/METADATA +1 -1
- {continual_foragax-0.39.0.dist-info → continual_foragax-0.41.0.dist-info}/RECORD +9 -9
- foragax/env.py +515 -409
- foragax/objects.py +10 -7
- foragax/registry.py +0 -3
- foragax/rendering.py +9 -33
- {continual_foragax-0.39.0.dist-info → continual_foragax-0.41.0.dist-info}/WHEEL +0 -0
- {continual_foragax-0.39.0.dist-info → continual_foragax-0.41.0.dist-info}/entry_points.txt +0 -0
- {continual_foragax-0.39.0.dist-info → continual_foragax-0.41.0.dist-info}/top_level.txt +0 -0
foragax/env.py
CHANGED
|
@@ -6,7 +6,7 @@ Source: https://github.com/andnp/Forager
|
|
|
6
6
|
from dataclasses import dataclass
|
|
7
7
|
from enum import IntEnum
|
|
8
8
|
from functools import partial
|
|
9
|
-
from typing import Any, Dict,
|
|
9
|
+
from typing import Any, Dict, Tuple, Union
|
|
10
10
|
|
|
11
11
|
import jax
|
|
12
12
|
import jax.numpy as jnp
|
|
@@ -26,6 +26,15 @@ from foragax.rendering import apply_true_borders
|
|
|
26
26
|
from foragax.weather import get_temperature
|
|
27
27
|
|
|
28
28
|
|
|
29
|
+
ID_DTYPE = jnp.int32 # Object type ID (0 = empty, >0 = object type)
|
|
30
|
+
TIMER_DTYPE = jnp.int32 # Respawn countdown (0 = no timer, >0 = countdown)
|
|
31
|
+
TIME_DTYPE = jnp.int32 # Timesteps (spawn time, current time)
|
|
32
|
+
PARAM_DTYPE = jnp.float16 # Per-instance object parameters
|
|
33
|
+
COLOR_DTYPE = jnp.uint8 # RGB color channels (0-255)
|
|
34
|
+
BIOME_ID_DTYPE = jnp.int16 # Biome assignment for each cell
|
|
35
|
+
REWARD_DTYPE = jnp.float32 # Reward values
|
|
36
|
+
|
|
37
|
+
|
|
29
38
|
class Actions(IntEnum):
|
|
30
39
|
DOWN = 0
|
|
31
40
|
RIGHT = 1
|
|
@@ -74,14 +83,14 @@ class ObjectState:
|
|
|
74
83
|
"""Create an empty ObjectState for the given grid size."""
|
|
75
84
|
h, w = size[1], size[0]
|
|
76
85
|
return cls(
|
|
77
|
-
object_id=jnp.zeros((h, w), dtype=
|
|
78
|
-
respawn_timer=jnp.zeros((h, w), dtype=
|
|
79
|
-
respawn_object_id=jnp.zeros((h, w), dtype=
|
|
80
|
-
spawn_time=jnp.zeros((h, w), dtype=
|
|
81
|
-
color=jnp.full((h, w, 3), 255, dtype=
|
|
82
|
-
generation=jnp.zeros((h, w), dtype=
|
|
83
|
-
state_params=jnp.zeros((h, w, num_params), dtype=
|
|
84
|
-
biome_id=jnp.full((h, w), -1, dtype=
|
|
86
|
+
object_id=jnp.zeros((h, w), dtype=ID_DTYPE),
|
|
87
|
+
respawn_timer=jnp.zeros((h, w), dtype=TIMER_DTYPE),
|
|
88
|
+
respawn_object_id=jnp.zeros((h, w), dtype=ID_DTYPE),
|
|
89
|
+
spawn_time=jnp.zeros((h, w), dtype=TIME_DTYPE),
|
|
90
|
+
color=jnp.full((h, w, 3), 255, dtype=COLOR_DTYPE),
|
|
91
|
+
generation=jnp.zeros((h, w), dtype=ID_DTYPE),
|
|
92
|
+
state_params=jnp.zeros((h, w, num_params), dtype=PARAM_DTYPE),
|
|
93
|
+
biome_id=jnp.full((h, w), -1, dtype=BIOME_ID_DTYPE),
|
|
85
94
|
)
|
|
86
95
|
|
|
87
96
|
|
|
@@ -128,7 +137,6 @@ class ForagaxEnv(environment.Environment):
|
|
|
128
137
|
biomes: Tuple[Biome, ...] = (Biome(object_frequencies=()),),
|
|
129
138
|
nowrap: bool = False,
|
|
130
139
|
deterministic_spawn: bool = False,
|
|
131
|
-
teleport_interval: Optional[int] = None,
|
|
132
140
|
observation_type: str = "object",
|
|
133
141
|
dynamic_biomes: bool = False,
|
|
134
142
|
biome_consumption_threshold: float = 0.9,
|
|
@@ -154,7 +162,6 @@ class ForagaxEnv(environment.Environment):
|
|
|
154
162
|
self.observation_type = observation_type
|
|
155
163
|
self.nowrap = nowrap
|
|
156
164
|
self.deterministic_spawn = deterministic_spawn
|
|
157
|
-
self.teleport_interval = teleport_interval
|
|
158
165
|
self.dynamic_biomes = dynamic_biomes
|
|
159
166
|
self.biome_consumption_threshold = biome_consumption_threshold
|
|
160
167
|
self.dynamic_biome_spawn_empty = dynamic_biome_spawn_empty
|
|
@@ -230,13 +237,13 @@ class ForagaxEnv(environment.Environment):
|
|
|
230
237
|
# Create mask for the biome
|
|
231
238
|
start = jax.lax.select(
|
|
232
239
|
self.biome_starts[i, 0] == -1,
|
|
233
|
-
jnp.array([0, 0]),
|
|
234
|
-
self.biome_starts[i],
|
|
240
|
+
jnp.array([0, 0], dtype=jnp.int32),
|
|
241
|
+
self.biome_starts[i].astype(jnp.int32),
|
|
235
242
|
)
|
|
236
243
|
stop = jax.lax.select(
|
|
237
244
|
self.biome_stops[i, 0] == -1,
|
|
238
|
-
jnp.array(self.size),
|
|
239
|
-
self.biome_stops[i],
|
|
245
|
+
jnp.array(self.size, dtype=jnp.int32),
|
|
246
|
+
self.biome_stops[i].astype(jnp.int32),
|
|
240
247
|
)
|
|
241
248
|
rows = jnp.arange(self.size[1])[:, None]
|
|
242
249
|
cols = jnp.arange(self.size[0])
|
|
@@ -275,12 +282,18 @@ class ForagaxEnv(environment.Environment):
|
|
|
275
282
|
# color_indices maps from object_id-1 to color_channel_index
|
|
276
283
|
self.object_to_color_map = color_indices
|
|
277
284
|
|
|
285
|
+
# Rendering constants
|
|
286
|
+
self.agent_color_jax = jnp.array(AGENT.color, dtype=jnp.uint8)
|
|
287
|
+
self.white_color_jax = jnp.array([255, 255, 255], dtype=jnp.uint8)
|
|
288
|
+
self.grid_color_jax = jnp.zeros(3, dtype=jnp.uint8)
|
|
289
|
+
|
|
278
290
|
@property
|
|
279
291
|
def default_params(self) -> EnvParams:
|
|
280
292
|
return EnvParams(
|
|
281
293
|
max_steps_in_episode=None,
|
|
282
294
|
)
|
|
283
295
|
|
|
296
|
+
@partial(jax.named_call, name="_place_timer")
|
|
284
297
|
def _place_timer(
|
|
285
298
|
self,
|
|
286
299
|
object_state: ObjectState,
|
|
@@ -304,13 +317,24 @@ class ForagaxEnv(environment.Environment):
|
|
|
304
317
|
Returns:
|
|
305
318
|
Updated object_state with timer placed
|
|
306
319
|
"""
|
|
320
|
+
# Ensure inputs match ObjectState dtypes
|
|
321
|
+
object_type = jnp.array(object_type, dtype=ID_DTYPE)
|
|
322
|
+
timer_val = jnp.array(timer_val, dtype=TIMER_DTYPE)
|
|
323
|
+
y = jnp.array(y, dtype=jnp.int32)
|
|
324
|
+
x = jnp.array(x, dtype=jnp.int32)
|
|
307
325
|
|
|
308
326
|
# Handle permanent removal (timer_val == 0)
|
|
309
327
|
def place_empty():
|
|
310
328
|
return object_state.replace(
|
|
311
|
-
object_id=object_state.object_id.at[y, x].set(
|
|
312
|
-
|
|
313
|
-
|
|
329
|
+
object_id=object_state.object_id.at[y, x].set(
|
|
330
|
+
jnp.array(0, dtype=ID_DTYPE)
|
|
331
|
+
),
|
|
332
|
+
respawn_timer=object_state.respawn_timer.at[y, x].set(
|
|
333
|
+
jnp.array(0, dtype=TIMER_DTYPE)
|
|
334
|
+
),
|
|
335
|
+
respawn_object_id=object_state.respawn_object_id.at[y, x].set(
|
|
336
|
+
jnp.array(0, dtype=ID_DTYPE)
|
|
337
|
+
),
|
|
314
338
|
)
|
|
315
339
|
|
|
316
340
|
# Handle timer placement
|
|
@@ -318,7 +342,9 @@ class ForagaxEnv(environment.Environment):
|
|
|
318
342
|
# Non-random: place at original position
|
|
319
343
|
def place_at_position():
|
|
320
344
|
return object_state.replace(
|
|
321
|
-
object_id=object_state.object_id.at[y, x].set(
|
|
345
|
+
object_id=object_state.object_id.at[y, x].set(
|
|
346
|
+
jnp.array(0, dtype=ID_DTYPE)
|
|
347
|
+
),
|
|
322
348
|
respawn_timer=object_state.respawn_timer.at[y, x].set(timer_val),
|
|
323
349
|
respawn_object_id=object_state.respawn_object_id.at[y, x].set(
|
|
324
350
|
object_type
|
|
@@ -328,9 +354,15 @@ class ForagaxEnv(environment.Environment):
|
|
|
328
354
|
# Random: place at random location in same biome
|
|
329
355
|
def place_randomly():
|
|
330
356
|
# Clear the collected object's position
|
|
331
|
-
new_object_id = object_state.object_id.at[y, x].set(
|
|
332
|
-
|
|
333
|
-
|
|
357
|
+
new_object_id = object_state.object_id.at[y, x].set(
|
|
358
|
+
jnp.array(0, dtype=ID_DTYPE)
|
|
359
|
+
)
|
|
360
|
+
new_respawn_timer = object_state.respawn_timer.at[y, x].set(
|
|
361
|
+
jnp.array(0, dtype=TIMER_DTYPE)
|
|
362
|
+
)
|
|
363
|
+
new_respawn_object_id = object_state.respawn_object_id.at[y, x].set(
|
|
364
|
+
jnp.array(0, dtype=ID_DTYPE)
|
|
365
|
+
)
|
|
334
366
|
|
|
335
367
|
# Find valid spawn locations in the same biome
|
|
336
368
|
biome_id = object_state.biome_id[y, x]
|
|
@@ -338,13 +370,15 @@ class ForagaxEnv(environment.Environment):
|
|
|
338
370
|
empty_mask = new_object_id == 0
|
|
339
371
|
no_timer_mask = new_respawn_timer == 0
|
|
340
372
|
valid_spawn_mask = biome_mask & empty_mask & no_timer_mask
|
|
341
|
-
num_valid_spawns = jnp.sum(valid_spawn_mask)
|
|
373
|
+
num_valid_spawns = jnp.sum(valid_spawn_mask, dtype=jnp.int32)
|
|
342
374
|
|
|
343
375
|
y_indices, x_indices = jnp.nonzero(
|
|
344
376
|
valid_spawn_mask, size=self.size[0] * self.size[1], fill_value=-1
|
|
345
377
|
)
|
|
346
378
|
valid_spawn_indices = jnp.stack([y_indices, x_indices], axis=1)
|
|
347
|
-
random_idx = jax.random.randint(
|
|
379
|
+
random_idx = jax.random.randint(
|
|
380
|
+
rand_key, (), jnp.array(0, dtype=jnp.int32), num_valid_spawns
|
|
381
|
+
)
|
|
348
382
|
new_spawn_pos = valid_spawn_indices[random_idx]
|
|
349
383
|
|
|
350
384
|
# Place timer at the new random position
|
|
@@ -365,17 +399,11 @@ class ForagaxEnv(environment.Environment):
|
|
|
365
399
|
|
|
366
400
|
return jax.lax.cond(timer_val == 0, place_empty, place_timer)
|
|
367
401
|
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
action: Union[int, float, jax.Array],
|
|
373
|
-
params: EnvParams,
|
|
374
|
-
) -> Tuple[jax.Array, EnvState, jax.Array, jax.Array, Dict[Any, Any]]:
|
|
375
|
-
"""Perform single timestep state transition."""
|
|
402
|
+
@partial(jax.named_call, name="move_agent")
|
|
403
|
+
def _move_agent(
|
|
404
|
+
self, state: EnvState, action: Union[int, float, jax.Array]
|
|
405
|
+
) -> Tuple[jax.Array, jax.Array]:
|
|
376
406
|
current_objects = state.object_state.object_id
|
|
377
|
-
|
|
378
|
-
# 1. UPDATE AGENT POSITION
|
|
379
407
|
direction = DIRECTIONS[action]
|
|
380
408
|
new_pos = state.pos + direction
|
|
381
409
|
|
|
@@ -388,45 +416,44 @@ class ForagaxEnv(environment.Environment):
|
|
|
388
416
|
|
|
389
417
|
# Check for blocking objects
|
|
390
418
|
obj_at_new_pos = current_objects[new_pos[1], new_pos[0]]
|
|
391
|
-
is_blocking = self.object_blocking[obj_at_new_pos]
|
|
392
|
-
pos =
|
|
419
|
+
is_blocking = self.object_blocking[obj_at_new_pos.astype(jnp.int32)]
|
|
420
|
+
pos = jnp.where(
|
|
421
|
+
is_blocking[..., None],
|
|
422
|
+
state.pos.astype(jnp.int32),
|
|
423
|
+
new_pos.astype(jnp.int32),
|
|
424
|
+
)
|
|
425
|
+
return pos, new_pos
|
|
393
426
|
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
# Find the index of the furthest biome center
|
|
405
|
-
furthest_idx = jnp.argmax(distances)
|
|
406
|
-
new_pos = self.biome_centers_jax[furthest_idx]
|
|
407
|
-
return new_pos
|
|
408
|
-
|
|
409
|
-
pos = jax.lax.cond(should_teleport, teleport_fn, lambda: pos)
|
|
410
|
-
|
|
411
|
-
# 2. HANDLE COLLISIONS AND REWARDS
|
|
412
|
-
obj_at_pos = current_objects[pos[1], pos[0]]
|
|
413
|
-
is_collectable = self.object_collectable[obj_at_pos]
|
|
414
|
-
should_collect = is_collectable & (obj_at_pos > 0)
|
|
415
|
-
|
|
416
|
-
# Handle digestion: add reward to buffer if collected
|
|
417
|
-
digestion_buffer = state.digestion_buffer
|
|
427
|
+
@partial(jax.named_call, name="compute_reward")
|
|
428
|
+
def _compute_reward(
|
|
429
|
+
self,
|
|
430
|
+
state: EnvState,
|
|
431
|
+
pos: jax.Array,
|
|
432
|
+
key: jax.Array,
|
|
433
|
+
should_collect: jax.Array,
|
|
434
|
+
digestion_buffer: jax.Array,
|
|
435
|
+
obj_at_pos: jax.Array,
|
|
436
|
+
) -> Tuple[jax.Array, jax.Array]:
|
|
418
437
|
key, reward_subkey = jax.random.split(key)
|
|
419
438
|
|
|
420
439
|
object_params = state.object_state.state_params[pos[1], pos[0]]
|
|
421
440
|
object_reward = jax.lax.switch(
|
|
422
|
-
obj_at_pos
|
|
441
|
+
obj_at_pos.astype(jnp.int32),
|
|
442
|
+
self.reward_fns,
|
|
443
|
+
state.time,
|
|
444
|
+
reward_subkey,
|
|
445
|
+
object_params.astype(jnp.float32),
|
|
423
446
|
)
|
|
424
447
|
|
|
425
448
|
key, digestion_subkey = jax.random.split(key)
|
|
426
449
|
reward_delay = jax.lax.switch(
|
|
427
450
|
obj_at_pos, self.reward_delay_fns, state.time, digestion_subkey
|
|
428
451
|
)
|
|
429
|
-
reward = jnp.where(
|
|
452
|
+
reward = jnp.where(
|
|
453
|
+
should_collect & (reward_delay == jnp.array(0, dtype=jnp.int32)),
|
|
454
|
+
object_reward,
|
|
455
|
+
0.0,
|
|
456
|
+
)
|
|
430
457
|
if self.max_reward_delay > 0:
|
|
431
458
|
# Add delayed rewards to buffer
|
|
432
459
|
digestion_buffer = jax.lax.cond(
|
|
@@ -440,20 +467,33 @@ class ForagaxEnv(environment.Environment):
|
|
|
440
467
|
current_index = state.time % self.max_reward_delay
|
|
441
468
|
reward += digestion_buffer[current_index]
|
|
442
469
|
digestion_buffer = digestion_buffer.at[current_index].set(0.0)
|
|
470
|
+
return reward, digestion_buffer
|
|
443
471
|
|
|
472
|
+
@partial(jax.named_call, name="respawn_logic")
|
|
473
|
+
def _respawn_logic(
|
|
474
|
+
self,
|
|
475
|
+
state: EnvState,
|
|
476
|
+
pos: jax.Array,
|
|
477
|
+
key: jax.Array,
|
|
478
|
+
current_objects: jax.Array,
|
|
479
|
+
is_collectable: jax.Array,
|
|
480
|
+
obj_at_pos: jax.Array,
|
|
481
|
+
) -> ObjectState:
|
|
444
482
|
# 3. HANDLE OBJECT COLLECTION AND RESPAWNING
|
|
445
483
|
key, regen_subkey, rand_key = jax.random.split(key, 3)
|
|
446
484
|
|
|
447
485
|
# Decrement respawn timers
|
|
448
|
-
has_timer = state.object_state.respawn_timer > 0
|
|
486
|
+
has_timer = state.object_state.respawn_timer > jnp.array(0, dtype=TIMER_DTYPE)
|
|
449
487
|
new_respawn_timer = jnp.where(
|
|
450
488
|
has_timer,
|
|
451
|
-
state.object_state.respawn_timer - 1,
|
|
489
|
+
state.object_state.respawn_timer - jnp.array(1, dtype=TIMER_DTYPE),
|
|
452
490
|
state.object_state.respawn_timer,
|
|
453
491
|
)
|
|
454
492
|
|
|
455
493
|
# Track which cells have timers that just reached 0
|
|
456
|
-
just_respawned = has_timer & (
|
|
494
|
+
just_respawned = has_timer & (
|
|
495
|
+
new_respawn_timer == jnp.array(0, dtype=TIMER_DTYPE)
|
|
496
|
+
)
|
|
457
497
|
|
|
458
498
|
# Respawn objects where timer reached 0
|
|
459
499
|
new_object_id = jnp.where(
|
|
@@ -465,7 +505,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
465
505
|
# Clear respawn_object_id for cells that just respawned
|
|
466
506
|
new_respawn_object_id = jnp.where(
|
|
467
507
|
just_respawned,
|
|
468
|
-
0,
|
|
508
|
+
jnp.array(0, dtype=ID_DTYPE),
|
|
469
509
|
state.object_state.respawn_object_id,
|
|
470
510
|
)
|
|
471
511
|
|
|
@@ -478,16 +518,19 @@ class ForagaxEnv(environment.Environment):
|
|
|
478
518
|
regen_delay = jax.lax.switch(
|
|
479
519
|
obj_at_pos, self.regen_delay_fns, state.time, regen_subkey
|
|
480
520
|
)
|
|
521
|
+
# Cast timer_countdown to match ObjectState.respawn_timer dtype
|
|
481
522
|
timer_countdown = jax.lax.cond(
|
|
482
523
|
regen_delay == jnp.iinfo(jnp.int32).max,
|
|
483
|
-
lambda: 0, # No timer (permanent removal)
|
|
484
|
-
lambda: regen_delay + 1, # Timer countdown
|
|
524
|
+
lambda: jnp.array(0, dtype=TIMER_DTYPE), # No timer (permanent removal)
|
|
525
|
+
lambda: (regen_delay + 1).astype(TIMER_DTYPE), # Timer countdown
|
|
485
526
|
)
|
|
486
527
|
|
|
487
528
|
# If collected, replace object with timer; otherwise, keep it
|
|
488
529
|
val_at_pos = current_objects[pos[1], pos[0]]
|
|
489
530
|
# Use original should_collect for consumption tracking
|
|
490
|
-
should_collect_now = is_collectable & (
|
|
531
|
+
should_collect_now = is_collectable & (
|
|
532
|
+
val_at_pos > jnp.array(0, dtype=ID_DTYPE)
|
|
533
|
+
)
|
|
491
534
|
|
|
492
535
|
# Create updated object state with new respawn_timer, object_id, and spawn_time
|
|
493
536
|
object_state = state.object_state.replace(
|
|
@@ -511,12 +554,17 @@ class ForagaxEnv(environment.Environment):
|
|
|
511
554
|
),
|
|
512
555
|
lambda: object_state,
|
|
513
556
|
)
|
|
557
|
+
return object_state
|
|
514
558
|
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
559
|
+
@partial(jax.named_call, name="dynamic_biomes")
|
|
560
|
+
def _dynamic_biomes(
|
|
561
|
+
self,
|
|
562
|
+
state: EnvState,
|
|
563
|
+
pos: jax.Array,
|
|
564
|
+
key: jax.Array,
|
|
565
|
+
object_state: ObjectState,
|
|
566
|
+
should_collect: jax.Array,
|
|
567
|
+
) -> Tuple[ObjectState, BiomeState]:
|
|
520
568
|
if self.dynamic_biomes:
|
|
521
569
|
# Update consumption count if an object was collected
|
|
522
570
|
# Only count if the object belongs to the current generation of its biome
|
|
@@ -528,7 +576,9 @@ class ForagaxEnv(environment.Environment):
|
|
|
528
576
|
biome_consumption_count = state.biome_state.consumption_count
|
|
529
577
|
biome_consumption_count = jax.lax.cond(
|
|
530
578
|
should_collect & is_current_generation,
|
|
531
|
-
lambda: biome_consumption_count.at[collected_biome_id].add(
|
|
579
|
+
lambda: biome_consumption_count.at[collected_biome_id].add(
|
|
580
|
+
jnp.array(1, dtype=ID_DTYPE)
|
|
581
|
+
),
|
|
532
582
|
lambda: biome_consumption_count,
|
|
533
583
|
)
|
|
534
584
|
|
|
@@ -545,8 +595,73 @@ class ForagaxEnv(environment.Environment):
|
|
|
545
595
|
state.time,
|
|
546
596
|
respawn_key,
|
|
547
597
|
)
|
|
598
|
+
return object_state, biome_state
|
|
548
599
|
else:
|
|
549
|
-
|
|
600
|
+
return object_state, state.biome_state
|
|
601
|
+
|
|
602
|
+
@partial(jax.named_call, name="reward_grid")
|
|
603
|
+
def _reward_grid(self, state: EnvState, object_state: ObjectState) -> jax.Array:
|
|
604
|
+
# Compute reward at each grid position
|
|
605
|
+
fixed_key = jax.random.key(0) # Fixed key for deterministic reward computation
|
|
606
|
+
|
|
607
|
+
def compute_reward(obj_id, params):
|
|
608
|
+
return jax.lax.cond(
|
|
609
|
+
obj_id > jnp.array(0, dtype=ID_DTYPE),
|
|
610
|
+
lambda: jax.lax.switch(
|
|
611
|
+
obj_id.astype(jnp.int32),
|
|
612
|
+
self.reward_fns,
|
|
613
|
+
state.time,
|
|
614
|
+
fixed_key,
|
|
615
|
+
params.astype(jnp.float32),
|
|
616
|
+
),
|
|
617
|
+
lambda: 0.0,
|
|
618
|
+
)
|
|
619
|
+
|
|
620
|
+
reward_grid = jax.vmap(jax.vmap(compute_reward))(
|
|
621
|
+
object_state.object_id.astype(ID_DTYPE),
|
|
622
|
+
object_state.state_params.astype(PARAM_DTYPE),
|
|
623
|
+
)
|
|
624
|
+
return reward_grid
|
|
625
|
+
|
|
626
|
+
def step_env(
|
|
627
|
+
self,
|
|
628
|
+
key: jax.Array,
|
|
629
|
+
state: EnvState,
|
|
630
|
+
action: Union[int, float, jax.Array],
|
|
631
|
+
params: EnvParams,
|
|
632
|
+
) -> Tuple[jax.Array, EnvState, jax.Array, jax.Array, Dict[Any, Any]]:
|
|
633
|
+
"""Perform single timestep state transition."""
|
|
634
|
+
current_objects = state.object_state.object_id
|
|
635
|
+
pos, new_pos = self._move_agent(state, action)
|
|
636
|
+
|
|
637
|
+
with jax.named_scope("compute_reward"):
|
|
638
|
+
# 2. HANDLE COLLISIONS AND REWARDS
|
|
639
|
+
obj_at_pos = current_objects[pos[1], pos[0]]
|
|
640
|
+
is_collectable = self.object_collectable[obj_at_pos]
|
|
641
|
+
should_collect = is_collectable & (
|
|
642
|
+
obj_at_pos > jnp.array(0, dtype=ID_DTYPE)
|
|
643
|
+
)
|
|
644
|
+
|
|
645
|
+
# Handle digestion: add reward to buffer if collected
|
|
646
|
+
digestion_buffer = state.digestion_buffer
|
|
647
|
+
key, reward_subkey = jax.random.split(key)
|
|
648
|
+
|
|
649
|
+
reward, digestion_buffer = self._compute_reward(
|
|
650
|
+
state, pos, key, should_collect, digestion_buffer, obj_at_pos
|
|
651
|
+
)
|
|
652
|
+
|
|
653
|
+
object_state = self._respawn_logic(
|
|
654
|
+
state, pos, key, current_objects, is_collectable, obj_at_pos
|
|
655
|
+
)
|
|
656
|
+
|
|
657
|
+
# 3.5. HANDLE OBJECT EXPIRY
|
|
658
|
+
# Only process expiry if there are objects that can expire
|
|
659
|
+
key, object_state = self.expire_objects(key, state, object_state)
|
|
660
|
+
|
|
661
|
+
# 3.6. HANDLE DYNAMIC BIOME CONSUMPTION AND RESPAWNING
|
|
662
|
+
object_state, biome_state = self._dynamic_biomes(
|
|
663
|
+
state, pos, key, object_state, should_collect
|
|
664
|
+
)
|
|
550
665
|
|
|
551
666
|
info = {"discount": self.discount(state, params)}
|
|
552
667
|
temperatures = jnp.zeros(len(self.objects))
|
|
@@ -557,33 +672,40 @@ class ForagaxEnv(environment.Environment):
|
|
|
557
672
|
)
|
|
558
673
|
info["temperatures"] = temperatures
|
|
559
674
|
info["biome_id"] = object_state.biome_id[pos[1], pos[0]]
|
|
560
|
-
info["object_collected_id"] =
|
|
675
|
+
info["object_collected_id"] = jnp.where(
|
|
676
|
+
should_collect,
|
|
677
|
+
obj_at_pos.astype(ID_DTYPE),
|
|
678
|
+
jnp.array(-1, dtype=ID_DTYPE),
|
|
679
|
+
)
|
|
561
680
|
|
|
562
681
|
# 4. UPDATE STATE
|
|
682
|
+
# Ensure all fields have canonical dtypes for consistency (e.g., for gymnax step selection)
|
|
683
|
+
object_state = object_state.replace(
|
|
684
|
+
object_id=object_state.object_id.astype(ID_DTYPE),
|
|
685
|
+
respawn_timer=object_state.respawn_timer.astype(TIMER_DTYPE),
|
|
686
|
+
respawn_object_id=object_state.respawn_object_id.astype(ID_DTYPE),
|
|
687
|
+
spawn_time=object_state.spawn_time.astype(TIME_DTYPE),
|
|
688
|
+
color=object_state.color.astype(COLOR_DTYPE),
|
|
689
|
+
generation=object_state.generation.astype(ID_DTYPE),
|
|
690
|
+
state_params=object_state.state_params.astype(PARAM_DTYPE),
|
|
691
|
+
biome_id=object_state.biome_id.astype(BIOME_ID_DTYPE),
|
|
692
|
+
)
|
|
693
|
+
biome_state = biome_state.replace(
|
|
694
|
+
consumption_count=biome_state.consumption_count.astype(ID_DTYPE),
|
|
695
|
+
total_objects=biome_state.total_objects.astype(ID_DTYPE),
|
|
696
|
+
generation=biome_state.generation.astype(ID_DTYPE),
|
|
697
|
+
)
|
|
698
|
+
|
|
563
699
|
state = EnvState(
|
|
564
|
-
pos=pos,
|
|
565
|
-
time=state.time + 1,
|
|
566
|
-
digestion_buffer=digestion_buffer,
|
|
700
|
+
pos=pos.astype(jnp.int32),
|
|
701
|
+
time=jnp.array(state.time + 1, dtype=jnp.int32),
|
|
702
|
+
digestion_buffer=digestion_buffer.astype(REWARD_DTYPE),
|
|
567
703
|
object_state=object_state,
|
|
568
704
|
biome_state=biome_state,
|
|
569
705
|
)
|
|
570
706
|
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
def compute_reward(obj_id, params):
|
|
575
|
-
return jax.lax.cond(
|
|
576
|
-
obj_id > 0,
|
|
577
|
-
lambda: jax.lax.switch(
|
|
578
|
-
obj_id, self.reward_fns, state.time, fixed_key, params
|
|
579
|
-
),
|
|
580
|
-
lambda: 0.0,
|
|
581
|
-
)
|
|
582
|
-
|
|
583
|
-
reward_grid = jax.vmap(jax.vmap(compute_reward))(
|
|
584
|
-
object_state.object_id, object_state.state_params
|
|
585
|
-
)
|
|
586
|
-
info["rewards"] = reward_grid
|
|
707
|
+
reward_grid = self._reward_grid(state, object_state)
|
|
708
|
+
info["rewards"] = reward_grid.astype(jnp.float16)
|
|
587
709
|
|
|
588
710
|
done = self.is_terminal(state, params)
|
|
589
711
|
return (
|
|
@@ -594,6 +716,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
594
716
|
info,
|
|
595
717
|
)
|
|
596
718
|
|
|
719
|
+
@partial(jax.named_call, name="expire_objects")
|
|
597
720
|
def expire_objects(
|
|
598
721
|
self, key, state, object_state: ObjectState
|
|
599
722
|
) -> Tuple[jax.Array, ObjectState]:
|
|
@@ -610,8 +733,8 @@ class ForagaxEnv(environment.Environment):
|
|
|
610
733
|
# Check if object should expire (age >= expiry_time and expiry_time >= 0)
|
|
611
734
|
should_expire = (
|
|
612
735
|
(object_ages >= expiry_times)
|
|
613
|
-
& (expiry_times >= 0)
|
|
614
|
-
& (current_objects_for_expiry > 0)
|
|
736
|
+
& (expiry_times >= jnp.array(0, dtype=TIME_DTYPE))
|
|
737
|
+
& (current_objects_for_expiry > jnp.array(0, dtype=ID_DTYPE))
|
|
615
738
|
)
|
|
616
739
|
|
|
617
740
|
# Only process expiry if there are actually objects to expire
|
|
@@ -646,8 +769,8 @@ class ForagaxEnv(environment.Environment):
|
|
|
646
769
|
)
|
|
647
770
|
timer_countdown = jax.lax.cond(
|
|
648
771
|
exp_delay == jnp.iinfo(jnp.int32).max,
|
|
649
|
-
lambda: 0,
|
|
650
|
-
lambda: exp_delay + 1,
|
|
772
|
+
lambda: jnp.array(0, dtype=TIMER_DTYPE),
|
|
773
|
+
lambda: (exp_delay + 1).astype(TIMER_DTYPE),
|
|
651
774
|
)
|
|
652
775
|
|
|
653
776
|
respawn_random = self.object_random_respawn[obj_id]
|
|
@@ -707,147 +830,189 @@ class ForagaxEnv(environment.Environment):
|
|
|
707
830
|
biome_state.consumption_count >= self.biome_consumption_threshold
|
|
708
831
|
)
|
|
709
832
|
|
|
710
|
-
|
|
711
|
-
key, subkey = jax.random.split(key)
|
|
712
|
-
biome_keys = jax.random.split(subkey, num_biomes)
|
|
833
|
+
any_respawn = jnp.any(should_respawn)
|
|
713
834
|
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
835
|
+
def do_respawn(args):
|
|
836
|
+
(
|
|
837
|
+
object_state,
|
|
838
|
+
biome_state,
|
|
839
|
+
should_respawn,
|
|
840
|
+
key,
|
|
841
|
+
) = args
|
|
842
|
+
# Split key for all biomes in parallel
|
|
843
|
+
key, subkey = jax.random.split(key)
|
|
844
|
+
biome_keys = jax.random.split(subkey, num_biomes)
|
|
845
|
+
|
|
846
|
+
# Compute all new spawns in parallel using vmap for random, switch for deterministic
|
|
847
|
+
if self.deterministic_spawn:
|
|
848
|
+
# Use switch to dispatch to concrete biome spawns for deterministic
|
|
849
|
+
def make_spawn_fn(biome_idx):
|
|
850
|
+
def spawn_fn(key):
|
|
851
|
+
return self._spawn_biome_objects(
|
|
852
|
+
biome_idx, key, deterministic=True
|
|
853
|
+
)
|
|
720
854
|
|
|
721
|
-
|
|
855
|
+
return spawn_fn
|
|
722
856
|
|
|
723
|
-
|
|
857
|
+
spawn_fns = [make_spawn_fn(idx) for idx in range(num_biomes)]
|
|
724
858
|
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
859
|
+
# Apply switch for each biome
|
|
860
|
+
all_new_objects_list = []
|
|
861
|
+
all_new_colors_list = []
|
|
862
|
+
all_new_params_list = []
|
|
863
|
+
for i in range(num_biomes):
|
|
864
|
+
obj, col, par = jax.lax.switch(i, spawn_fns, biome_keys[i])
|
|
865
|
+
all_new_objects_list.append(obj)
|
|
866
|
+
all_new_colors_list.append(col)
|
|
867
|
+
all_new_params_list.append(par)
|
|
868
|
+
|
|
869
|
+
all_new_objects = jnp.stack(all_new_objects_list)
|
|
870
|
+
all_new_colors = jnp.stack(all_new_colors_list)
|
|
871
|
+
all_new_params = jnp.stack(all_new_params_list)
|
|
872
|
+
else:
|
|
873
|
+
# Random spawn works with vmap
|
|
874
|
+
all_new_objects, all_new_colors, all_new_params = jax.vmap(
|
|
875
|
+
lambda i, k: self._spawn_biome_objects(i, k, deterministic=False)
|
|
876
|
+
)(jnp.arange(num_biomes), biome_keys)
|
|
877
|
+
|
|
878
|
+
# Initialize updated grids
|
|
879
|
+
new_obj_id = object_state.object_id
|
|
880
|
+
new_color = object_state.color
|
|
881
|
+
new_params = object_state.state_params
|
|
882
|
+
new_spawn = object_state.spawn_time
|
|
883
|
+
new_gen = object_state.generation
|
|
884
|
+
|
|
885
|
+
# Update biome state
|
|
886
|
+
new_consumption_count = jnp.where(
|
|
887
|
+
should_respawn,
|
|
888
|
+
jnp.array(0, dtype=ID_DTYPE),
|
|
889
|
+
biome_state.consumption_count,
|
|
890
|
+
)
|
|
891
|
+
new_generation = biome_state.generation + should_respawn.astype(ID_DTYPE)
|
|
756
892
|
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
893
|
+
# Compute new total objects for respawning biomes
|
|
894
|
+
def count_objects(i):
|
|
895
|
+
return jnp.sum(
|
|
896
|
+
(all_new_objects[i] > 0) & self.biome_masks_array[i], dtype=ID_DTYPE
|
|
897
|
+
)
|
|
760
898
|
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
899
|
+
new_object_counts = jax.vmap(count_objects)(jnp.arange(num_biomes))
|
|
900
|
+
new_total_objects = jnp.where(
|
|
901
|
+
should_respawn, new_object_counts, biome_state.total_objects
|
|
902
|
+
)
|
|
765
903
|
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
904
|
+
new_biome_state = BiomeState(
|
|
905
|
+
consumption_count=new_consumption_count,
|
|
906
|
+
total_objects=new_total_objects,
|
|
907
|
+
generation=new_generation,
|
|
908
|
+
)
|
|
771
909
|
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
910
|
+
# Update grids for respawning biomes
|
|
911
|
+
for i in range(num_biomes):
|
|
912
|
+
biome_mask = self.biome_masks_array[i]
|
|
913
|
+
new_gen_value = new_biome_state.generation[i]
|
|
776
914
|
|
|
777
|
-
|
|
778
|
-
|
|
915
|
+
# Update mask: biome area AND needs respawn
|
|
916
|
+
should_update = biome_mask & should_respawn[i][..., None]
|
|
779
917
|
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
new_spawn_valid = all_new_objects[i] > 0
|
|
918
|
+
# 1. Merge: Overwrite with new objects if present, otherwise keep existing
|
|
919
|
+
new_spawn_valid = all_new_objects[i] > 0
|
|
783
920
|
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
921
|
+
merged_objs = jnp.where(new_spawn_valid, all_new_objects[i], new_obj_id)
|
|
922
|
+
merged_colors = jnp.where(
|
|
923
|
+
new_spawn_valid[..., None], all_new_colors[i], new_color
|
|
924
|
+
)
|
|
925
|
+
merged_params = jnp.where(
|
|
926
|
+
new_spawn_valid[..., None], all_new_params[i], new_params
|
|
927
|
+
)
|
|
791
928
|
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
merged_spawn = jnp.where(new_spawn_valid, current_time, new_spawn)
|
|
929
|
+
merged_gen = jnp.where(new_spawn_valid, new_gen_value, new_gen)
|
|
930
|
+
merged_spawn = jnp.where(new_spawn_valid, current_time, new_spawn)
|
|
795
931
|
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
|
|
932
|
+
if self.dynamic_biome_spawn_empty > 0:
|
|
933
|
+
key, dropout_key = jax.random.split(key)
|
|
934
|
+
keep_mask = jax.random.bernoulli(
|
|
935
|
+
dropout_key,
|
|
936
|
+
1.0 - self.dynamic_biome_spawn_empty,
|
|
937
|
+
merged_objs.shape,
|
|
938
|
+
)
|
|
939
|
+
dropout_mask = should_update & keep_mask
|
|
940
|
+
final_objs = jnp.where(
|
|
941
|
+
dropout_mask, merged_objs, jnp.array(0, dtype=ID_DTYPE)
|
|
942
|
+
)
|
|
943
|
+
final_colors = jnp.where(
|
|
944
|
+
dropout_mask[..., None],
|
|
945
|
+
merged_colors,
|
|
946
|
+
jnp.array(0, dtype=COLOR_DTYPE),
|
|
947
|
+
)
|
|
948
|
+
final_params = jnp.where(
|
|
949
|
+
dropout_mask[..., None],
|
|
950
|
+
merged_params,
|
|
951
|
+
jnp.array(0, dtype=PARAM_DTYPE),
|
|
952
|
+
)
|
|
953
|
+
final_gen = jnp.where(
|
|
954
|
+
dropout_mask, merged_gen, jnp.array(0, dtype=ID_DTYPE)
|
|
955
|
+
)
|
|
956
|
+
final_spawn = jnp.where(
|
|
957
|
+
dropout_mask, merged_spawn, jnp.array(0, dtype=TIME_DTYPE)
|
|
958
|
+
)
|
|
959
|
+
else:
|
|
960
|
+
final_objs = merged_objs
|
|
961
|
+
final_colors = merged_colors
|
|
962
|
+
final_params = merged_params
|
|
963
|
+
final_gen = merged_gen
|
|
964
|
+
final_spawn = merged_spawn
|
|
965
|
+
|
|
966
|
+
new_obj_id = jnp.where(should_update, final_objs, new_obj_id)
|
|
967
|
+
new_color = jnp.where(should_update[..., None], final_colors, new_color)
|
|
968
|
+
new_params = jnp.where(
|
|
969
|
+
should_update[..., None], final_params, new_params
|
|
802
970
|
)
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
# Apply dropout only to the merged result and associated metadata
|
|
806
|
-
final_objs = jnp.where(dropout_mask, merged_objs, 0)
|
|
807
|
-
final_colors = jnp.where(dropout_mask[..., None], merged_colors, 0)
|
|
808
|
-
final_params = jnp.where(dropout_mask[..., None], merged_params, 0)
|
|
809
|
-
final_gen = jnp.where(dropout_mask, merged_gen, 0)
|
|
810
|
-
final_spawn = jnp.where(dropout_mask, merged_spawn, 0)
|
|
811
|
-
else:
|
|
812
|
-
final_objs = merged_objs
|
|
813
|
-
final_colors = merged_colors
|
|
814
|
-
final_params = merged_params
|
|
815
|
-
final_gen = merged_gen
|
|
816
|
-
final_spawn = merged_spawn
|
|
817
|
-
|
|
818
|
-
# 3. Write back: Only update where should_update is true
|
|
819
|
-
new_obj_id = jnp.where(should_update, final_objs, new_obj_id)
|
|
820
|
-
new_color = jnp.where(should_update[..., None], final_colors, new_color)
|
|
821
|
-
new_params = jnp.where(should_update[..., None], final_params, new_params)
|
|
822
|
-
new_gen = jnp.where(should_update, final_gen, new_gen)
|
|
823
|
-
new_spawn = jnp.where(should_update, final_spawn, new_spawn)
|
|
824
|
-
|
|
825
|
-
# Clear timers in respawning biomes
|
|
826
|
-
new_respawn_timer = object_state.respawn_timer
|
|
827
|
-
new_respawn_object_id = object_state.respawn_object_id
|
|
828
|
-
for i in range(num_biomes):
|
|
829
|
-
biome_mask = self.biome_masks_array[i]
|
|
830
|
-
should_clear = biome_mask & should_respawn[i][..., None]
|
|
831
|
-
new_respawn_timer = jnp.where(should_clear, 0, new_respawn_timer)
|
|
832
|
-
new_respawn_object_id = jnp.where(should_clear, 0, new_respawn_object_id)
|
|
971
|
+
new_gen = jnp.where(should_update, final_gen, new_gen)
|
|
972
|
+
new_spawn = jnp.where(should_update, final_spawn, new_spawn)
|
|
833
973
|
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
974
|
+
# Clear timers in respawning biomes
|
|
975
|
+
new_respawn_timer = object_state.respawn_timer
|
|
976
|
+
new_respawn_object_id = object_state.respawn_object_id
|
|
977
|
+
for i in range(num_biomes):
|
|
978
|
+
biome_mask = self.biome_masks_array[i]
|
|
979
|
+
should_clear = biome_mask & should_respawn[i][..., None]
|
|
980
|
+
new_respawn_timer = jnp.where(
|
|
981
|
+
should_clear, jnp.array(0, dtype=TIMER_DTYPE), new_respawn_timer
|
|
982
|
+
)
|
|
983
|
+
new_respawn_object_id = jnp.where(
|
|
984
|
+
should_clear, jnp.array(0, dtype=ID_DTYPE), new_respawn_object_id
|
|
985
|
+
)
|
|
986
|
+
|
|
987
|
+
new_object_state = object_state.replace(
|
|
988
|
+
object_id=new_obj_id,
|
|
989
|
+
respawn_timer=new_respawn_timer,
|
|
990
|
+
respawn_object_id=new_respawn_object_id,
|
|
991
|
+
color=new_color,
|
|
992
|
+
state_params=new_params,
|
|
993
|
+
generation=new_gen,
|
|
994
|
+
spawn_time=new_spawn,
|
|
995
|
+
)
|
|
996
|
+
return new_object_state, new_biome_state, key
|
|
997
|
+
|
|
998
|
+
def no_respawn(args):
|
|
999
|
+
object_state, biome_state, _, key = args
|
|
1000
|
+
return object_state, biome_state, key
|
|
1001
|
+
|
|
1002
|
+
object_state, biome_state, key = jax.lax.cond(
|
|
1003
|
+
any_respawn,
|
|
1004
|
+
do_respawn,
|
|
1005
|
+
no_respawn,
|
|
1006
|
+
(object_state, biome_state, should_respawn, key),
|
|
842
1007
|
)
|
|
843
1008
|
|
|
844
|
-
return object_state,
|
|
1009
|
+
return object_state, biome_state, key
|
|
845
1010
|
|
|
846
1011
|
def reset_env(
|
|
847
1012
|
self, key: jax.Array, params: EnvParams
|
|
848
1013
|
) -> Tuple[jax.Array, EnvState]:
|
|
849
1014
|
"""Reset environment state."""
|
|
850
|
-
num_object_params =
|
|
1015
|
+
num_object_params = 3 + 2 * self.num_fourier_terms
|
|
851
1016
|
object_state = ObjectState.create_empty(self.size, num_object_params)
|
|
852
1017
|
|
|
853
1018
|
key, iter_key = jax.random.split(key)
|
|
@@ -859,7 +1024,9 @@ class ForagaxEnv(environment.Environment):
|
|
|
859
1024
|
|
|
860
1025
|
# Set biome_id
|
|
861
1026
|
object_state = object_state.replace(
|
|
862
|
-
biome_id=jnp.where(
|
|
1027
|
+
biome_id=jnp.where(
|
|
1028
|
+
mask, jnp.array(i, dtype=BIOME_ID_DTYPE), object_state.biome_id
|
|
1029
|
+
)
|
|
863
1030
|
)
|
|
864
1031
|
|
|
865
1032
|
# Use unified spawn method
|
|
@@ -881,31 +1048,45 @@ class ForagaxEnv(environment.Environment):
|
|
|
881
1048
|
|
|
882
1049
|
# Initialize biome consumption tracking
|
|
883
1050
|
num_biomes = self.biome_object_frequencies.shape[0]
|
|
884
|
-
biome_consumption_count = jnp.zeros(num_biomes, dtype=
|
|
885
|
-
biome_total_objects = jnp.zeros(num_biomes, dtype=
|
|
1051
|
+
biome_consumption_count = jnp.zeros(num_biomes, dtype=ID_DTYPE)
|
|
1052
|
+
biome_total_objects = jnp.zeros(num_biomes, dtype=ID_DTYPE)
|
|
886
1053
|
|
|
887
1054
|
# Count objects in each biome
|
|
888
1055
|
for i in range(num_biomes):
|
|
889
1056
|
mask = self.biome_masks[i]
|
|
890
1057
|
# Count non-empty objects (object_id > 0)
|
|
891
|
-
total = jnp.sum((object_state.object_id > 0) & mask)
|
|
1058
|
+
total = jnp.sum((object_state.object_id > 0) & mask, dtype=ID_DTYPE)
|
|
892
1059
|
biome_total_objects = biome_total_objects.at[i].set(total)
|
|
893
1060
|
|
|
894
|
-
biome_generation = jnp.zeros(num_biomes, dtype=
|
|
1061
|
+
biome_generation = jnp.zeros(num_biomes, dtype=ID_DTYPE)
|
|
1062
|
+
|
|
1063
|
+
# Final state cleanup to ensure type consistency
|
|
1064
|
+
object_state = object_state.replace(
|
|
1065
|
+
object_id=object_state.object_id.astype(ID_DTYPE),
|
|
1066
|
+
respawn_timer=object_state.respawn_timer.astype(TIMER_DTYPE),
|
|
1067
|
+
respawn_object_id=object_state.respawn_object_id.astype(ID_DTYPE),
|
|
1068
|
+
spawn_time=object_state.spawn_time.astype(TIME_DTYPE),
|
|
1069
|
+
color=object_state.color.astype(COLOR_DTYPE),
|
|
1070
|
+
generation=object_state.generation.astype(ID_DTYPE),
|
|
1071
|
+
state_params=object_state.state_params.astype(PARAM_DTYPE),
|
|
1072
|
+
biome_id=object_state.biome_id.astype(BIOME_ID_DTYPE),
|
|
1073
|
+
)
|
|
895
1074
|
|
|
896
1075
|
state = EnvState(
|
|
897
|
-
pos=agent_pos,
|
|
898
|
-
time=0,
|
|
899
|
-
digestion_buffer=jnp.zeros((self.max_reward_delay,)),
|
|
1076
|
+
pos=agent_pos.astype(jnp.int32),
|
|
1077
|
+
time=jnp.array(0, dtype=jnp.int32),
|
|
1078
|
+
digestion_buffer=jnp.zeros((self.max_reward_delay,), dtype=REWARD_DTYPE),
|
|
900
1079
|
object_state=object_state,
|
|
901
1080
|
biome_state=BiomeState(
|
|
902
|
-
consumption_count=biome_consumption_count,
|
|
903
|
-
total_objects=biome_total_objects,
|
|
904
|
-
generation=biome_generation,
|
|
1081
|
+
consumption_count=biome_consumption_count.astype(ID_DTYPE),
|
|
1082
|
+
total_objects=biome_total_objects.astype(ID_DTYPE),
|
|
1083
|
+
generation=biome_generation.astype(ID_DTYPE),
|
|
905
1084
|
),
|
|
906
1085
|
)
|
|
907
1086
|
|
|
908
|
-
return
|
|
1087
|
+
return jax.lax.stop_gradient(
|
|
1088
|
+
self.get_obs(state, params)
|
|
1089
|
+
), jax.lax.stop_gradient(state)
|
|
909
1090
|
|
|
910
1091
|
def _spawn_biome_objects(
|
|
911
1092
|
self,
|
|
@@ -937,16 +1118,17 @@ class ForagaxEnv(environment.Environment):
|
|
|
937
1118
|
biome_size = int(self.biome_sizes[biome_idx])
|
|
938
1119
|
|
|
939
1120
|
grid = jnp.linspace(0, 1, biome_size, endpoint=False)
|
|
940
|
-
biome_objects_flat =
|
|
941
|
-
|
|
942
|
-
|
|
1121
|
+
biome_objects_flat = (
|
|
1122
|
+
len(biome_freqs)
|
|
1123
|
+
- jnp.searchsorted(jnp.cumsum(biome_freqs[::-1]), grid, side="right")
|
|
1124
|
+
).astype(ID_DTYPE)
|
|
943
1125
|
biome_objects_flat = jax.random.permutation(spawn_key, biome_objects_flat)
|
|
944
1126
|
|
|
945
1127
|
# Reshape to match biome dimensions (use concrete dimensions)
|
|
946
1128
|
biome_objects = biome_objects_flat.reshape(biome_height, biome_width)
|
|
947
1129
|
|
|
948
1130
|
# Place in full grid using slicing with static bounds
|
|
949
|
-
object_grid = jnp.zeros((self.size[1], self.size[0]), dtype=
|
|
1131
|
+
object_grid = jnp.zeros((self.size[1], self.size[0]), dtype=ID_DTYPE)
|
|
950
1132
|
object_grid = object_grid.at[
|
|
951
1133
|
biome_start[1] : biome_stop[1], biome_start[0] : biome_stop[0]
|
|
952
1134
|
].set(biome_objects)
|
|
@@ -960,10 +1142,10 @@ class ForagaxEnv(environment.Environment):
|
|
|
960
1142
|
)
|
|
961
1143
|
object_grid = (
|
|
962
1144
|
jnp.searchsorted(cumulative_freqs, grid_rand, side="right") - 1
|
|
963
|
-
)
|
|
1145
|
+
).astype(ID_DTYPE)
|
|
964
1146
|
|
|
965
1147
|
# Initialize color grid
|
|
966
|
-
color_grid = jnp.full((self.size[1], self.size[0], 3), 255, dtype=
|
|
1148
|
+
color_grid = jnp.full((self.size[1], self.size[0], 3), 255, dtype=COLOR_DTYPE)
|
|
967
1149
|
|
|
968
1150
|
# Sample ONE color per object type in this biome (not per instance)
|
|
969
1151
|
# This gives objects of the same type the same color within a biome generation
|
|
@@ -975,9 +1157,9 @@ class ForagaxEnv(environment.Environment):
|
|
|
975
1157
|
biome_object_colors = jax.random.randint(
|
|
976
1158
|
color_key,
|
|
977
1159
|
(num_actual_objects, 3),
|
|
978
|
-
minval=0,
|
|
979
|
-
maxval=256,
|
|
980
|
-
dtype=
|
|
1160
|
+
minval=jnp.array(0, dtype=COLOR_DTYPE),
|
|
1161
|
+
maxval=jnp.array(256, dtype=jnp.int32), # randint range is often int32
|
|
1162
|
+
dtype=COLOR_DTYPE,
|
|
981
1163
|
)
|
|
982
1164
|
|
|
983
1165
|
# Assign colors based on object type (starting from index 1)
|
|
@@ -989,9 +1171,9 @@ class ForagaxEnv(environment.Environment):
|
|
|
989
1171
|
color_grid = jnp.where(obj_mask[..., None], obj_color, color_grid)
|
|
990
1172
|
|
|
991
1173
|
# Initialize parameters grid
|
|
992
|
-
num_object_params =
|
|
1174
|
+
num_object_params = 3 + 2 * self.num_fourier_terms
|
|
993
1175
|
params_grid = jnp.zeros(
|
|
994
|
-
(self.size[1], self.size[0], num_object_params), dtype=
|
|
1176
|
+
(self.size[1], self.size[0], num_object_params), dtype=PARAM_DTYPE
|
|
995
1177
|
)
|
|
996
1178
|
|
|
997
1179
|
# Generate per-object parameters for each object type
|
|
@@ -1018,7 +1200,9 @@ class ForagaxEnv(environment.Environment):
|
|
|
1018
1200
|
|
|
1019
1201
|
# Assign to all objects of this type in this biome
|
|
1020
1202
|
obj_mask = (object_grid == obj_idx) & biome_mask
|
|
1021
|
-
params_grid = jnp.where(
|
|
1203
|
+
params_grid = jnp.where(
|
|
1204
|
+
obj_mask[..., None], obj_params.astype(PARAM_DTYPE), params_grid
|
|
1205
|
+
)
|
|
1022
1206
|
|
|
1023
1207
|
return object_grid, color_grid, params_grid
|
|
1024
1208
|
|
|
@@ -1043,7 +1227,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
1043
1227
|
|
|
1044
1228
|
def state_space(self, params: EnvParams) -> spaces.Dict:
|
|
1045
1229
|
"""State space of the environment."""
|
|
1046
|
-
num_object_params =
|
|
1230
|
+
num_object_params = 3 + 2 * self.num_fourier_terms
|
|
1047
1231
|
return spaces.Dict(
|
|
1048
1232
|
{
|
|
1049
1233
|
"pos": spaces.Box(0, max(self.size), (2,), int),
|
|
@@ -1166,7 +1350,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
1166
1350
|
aperture = jnp.where(out_of_bounds[..., None], padding_value, values)
|
|
1167
1351
|
else:
|
|
1168
1352
|
# Object ID grid: use PADDING index
|
|
1169
|
-
padding_index = self.object_ids[-1]
|
|
1353
|
+
padding_index = self.object_ids[-1].astype(values.dtype)
|
|
1170
1354
|
aperture = jnp.where(out_of_bounds, padding_index, values)
|
|
1171
1355
|
else:
|
|
1172
1356
|
aperture = values
|
|
@@ -1215,9 +1399,11 @@ class ForagaxEnv(environment.Environment):
|
|
|
1215
1399
|
return jnp.zeros(obs_grid.shape + (0,), dtype=jnp.float32)
|
|
1216
1400
|
# Map object IDs to color channel indices
|
|
1217
1401
|
color_channels = jnp.where(
|
|
1218
|
-
obs_grid == 0,
|
|
1219
|
-
-1,
|
|
1220
|
-
jnp.take(
|
|
1402
|
+
obs_grid == jnp.array(0, dtype=obs_grid.dtype),
|
|
1403
|
+
jnp.array(-1, dtype=jnp.int32),
|
|
1404
|
+
jnp.take(
|
|
1405
|
+
self.object_to_color_map, obs_grid.astype(jnp.int32) - 1, axis=0
|
|
1406
|
+
),
|
|
1221
1407
|
)
|
|
1222
1408
|
obs = jax.nn.one_hot(color_channels, self.num_color_channels, axis=-1)
|
|
1223
1409
|
return obs
|
|
@@ -1237,9 +1423,8 @@ class ForagaxEnv(environment.Environment):
|
|
|
1237
1423
|
aperture_colors = color_aperture / 255.0
|
|
1238
1424
|
|
|
1239
1425
|
# Mask empty cells (object_id == 0) to white
|
|
1240
|
-
empty_mask = aperture == 0
|
|
1426
|
+
empty_mask = aperture == jnp.array(0, dtype=aperture.dtype)
|
|
1241
1427
|
white_color = jnp.ones(aperture_colors.shape, dtype=jnp.float32)
|
|
1242
|
-
|
|
1243
1428
|
obs = jnp.where(empty_mask[..., None], white_color, aperture_colors)
|
|
1244
1429
|
|
|
1245
1430
|
return obs
|
|
@@ -1279,12 +1464,15 @@ class ForagaxEnv(environment.Environment):
|
|
|
1279
1464
|
|
|
1280
1465
|
return spaces.Box(0, 1, obs_shape, float)
|
|
1281
1466
|
|
|
1282
|
-
def _compute_reward_grid(
|
|
1283
|
-
|
|
1467
|
+
def _compute_reward_grid(
|
|
1468
|
+
self, state: EnvState, object_id=None, state_params=None
|
|
1469
|
+
) -> jax.Array:
|
|
1470
|
+
"""Compute rewards for given positions. If no grid provided, uses full world."""
|
|
1471
|
+
if object_id is None:
|
|
1472
|
+
object_id = state.object_state.object_id
|
|
1473
|
+
if state_params is None:
|
|
1474
|
+
state_params = state.object_state.state_params
|
|
1284
1475
|
|
|
1285
|
-
Returns:
|
|
1286
|
-
Array of shape (H, W) with reward values for each cell
|
|
1287
|
-
"""
|
|
1288
1476
|
fixed_key = jax.random.key(0) # Fixed key for deterministic reward computation
|
|
1289
1477
|
|
|
1290
1478
|
def compute_reward(obj_id, params):
|
|
@@ -1296,9 +1484,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
1296
1484
|
lambda: 0.0,
|
|
1297
1485
|
)
|
|
1298
1486
|
|
|
1299
|
-
reward_grid = jax.vmap(jax.vmap(compute_reward))(
|
|
1300
|
-
state.object_state.object_id, state.object_state.state_params
|
|
1301
|
-
)
|
|
1487
|
+
reward_grid = jax.vmap(jax.vmap(compute_reward))(object_id, state_params)
|
|
1302
1488
|
return reward_grid
|
|
1303
1489
|
|
|
1304
1490
|
def _reward_to_color(self, reward: jax.Array) -> jax.Array:
|
|
@@ -1388,172 +1574,87 @@ class ForagaxEnv(environment.Environment):
|
|
|
1388
1574
|
|
|
1389
1575
|
img = jax.lax.fori_loop(0, len(self.object_ids), update_image, img)
|
|
1390
1576
|
|
|
1391
|
-
|
|
1392
|
-
|
|
1393
|
-
img = jax.image.resize(
|
|
1394
|
-
img,
|
|
1395
|
-
(self.size[1] * 3, self.size[0] * 3, 3),
|
|
1396
|
-
jax.image.ResizeMethod.NEAREST,
|
|
1397
|
-
)
|
|
1398
|
-
|
|
1399
|
-
# Compute rewards for all cells
|
|
1400
|
-
reward_grid = self._compute_reward_grid(state)
|
|
1401
|
-
|
|
1402
|
-
# Convert rewards to colors
|
|
1403
|
-
reward_colors = self._reward_to_color(reward_grid)
|
|
1404
|
-
|
|
1405
|
-
# Resize reward colors to match 3x scale and place in middle cells
|
|
1406
|
-
# We need to place reward colors at positions (i*3+1, j*3+1) for each (i,j)
|
|
1407
|
-
# Create index arrays for middle cells
|
|
1408
|
-
i_indices = jnp.arange(self.size[1])[:, None] * 3 + 1
|
|
1409
|
-
j_indices = jnp.arange(self.size[0])[None, :] * 3 + 1
|
|
1410
|
-
|
|
1411
|
-
# Broadcast and set middle cells
|
|
1412
|
-
img = img.at[i_indices, j_indices].set(reward_colors)
|
|
1413
|
-
|
|
1414
|
-
# Tint the agent's aperture
|
|
1577
|
+
# Define constants for all world modes
|
|
1578
|
+
alpha = 0.2
|
|
1415
1579
|
y_coords, x_coords, y_coords_adj, x_coords_adj = (
|
|
1416
1580
|
self._compute_aperture_coordinates(state.pos)
|
|
1417
1581
|
)
|
|
1418
1582
|
|
|
1419
|
-
alpha = 0.2
|
|
1420
|
-
agent_color = jnp.array(AGENT.color)
|
|
1421
|
-
|
|
1422
1583
|
if is_reward_mode:
|
|
1423
|
-
#
|
|
1424
|
-
|
|
1425
|
-
|
|
1426
|
-
|
|
1427
|
-
(self.size[1] * 3, self.size[0] * 3), dtype=bool
|
|
1428
|
-
)
|
|
1429
|
-
|
|
1430
|
-
# For each aperture cell, tint all 9 cells in its 3x3 block
|
|
1431
|
-
# Create meshgrid to get all aperture cell coordinates
|
|
1432
|
-
y_grid, x_grid = jnp.meshgrid(
|
|
1433
|
-
y_coords_adj.flatten(), x_coords_adj.flatten(), indexing="ij"
|
|
1434
|
-
)
|
|
1435
|
-
y_flat = y_grid.flatten()
|
|
1436
|
-
x_flat = x_grid.flatten()
|
|
1437
|
-
|
|
1438
|
-
# Create offset arrays for 3x3 blocks
|
|
1439
|
-
offsets = jnp.array(
|
|
1440
|
-
[[dy, dx] for dy in range(3) for dx in range(3)]
|
|
1441
|
-
)
|
|
1442
|
-
|
|
1443
|
-
# For each aperture cell, expand to 9 cells
|
|
1444
|
-
# We need to repeat each cell coordinate 9 times, then add offsets
|
|
1445
|
-
num_aperture_cells = y_flat.size
|
|
1446
|
-
y_base = jnp.repeat(
|
|
1447
|
-
y_flat * 3, 9
|
|
1448
|
-
) # Repeat each y coord 9 times and scale by 3
|
|
1449
|
-
x_base = jnp.repeat(
|
|
1450
|
-
x_flat * 3, 9
|
|
1451
|
-
) # Repeat each x coord 9 times and scale by 3
|
|
1452
|
-
y_offsets = jnp.tile(
|
|
1453
|
-
offsets[:, 0], num_aperture_cells
|
|
1454
|
-
) # Tile all 9 offsets
|
|
1455
|
-
x_offsets = jnp.tile(
|
|
1456
|
-
offsets[:, 1], num_aperture_cells
|
|
1457
|
-
) # Tile all 9 offsets
|
|
1458
|
-
y_expanded = y_base + y_offsets
|
|
1459
|
-
x_expanded = x_base + x_offsets
|
|
1460
|
-
|
|
1461
|
-
tint_mask = tint_mask.at[y_expanded, x_expanded].set(True)
|
|
1462
|
-
|
|
1463
|
-
original_colors = img
|
|
1464
|
-
tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
|
|
1465
|
-
img = jnp.where(tint_mask[..., None], tinted_colors, img)
|
|
1466
|
-
else:
|
|
1467
|
-
# Tint all 9 cells in each 3x3 block for aperture cells
|
|
1468
|
-
# Create meshgrid to get all aperture cell coordinates
|
|
1469
|
-
y_grid, x_grid = jnp.meshgrid(
|
|
1470
|
-
y_coords_adj.flatten(), x_coords_adj.flatten(), indexing="ij"
|
|
1471
|
-
)
|
|
1472
|
-
y_flat = y_grid.flatten()
|
|
1473
|
-
x_flat = x_grid.flatten()
|
|
1584
|
+
# Construct 3x intermediate image
|
|
1585
|
+
# Each cell is 3x3, with reward color in center
|
|
1586
|
+
reward_grid = self._compute_reward_grid(state)
|
|
1587
|
+
reward_colors = self._reward_to_color(reward_grid)
|
|
1474
1588
|
|
|
1475
|
-
|
|
1476
|
-
|
|
1477
|
-
|
|
1478
|
-
|
|
1589
|
+
# Each cell has its base color in 8 pixels and reward color in 1 (center)
|
|
1590
|
+
# Create a 3x3 pattern mask for center pixels
|
|
1591
|
+
cell_mask = jnp.array(
|
|
1592
|
+
[[False, False, False], [False, True, False], [False, False, False]]
|
|
1593
|
+
)
|
|
1594
|
+
grid_reward_mask = jnp.tile(cell_mask, (self.size[1], self.size[0]))
|
|
1479
1595
|
|
|
1480
|
-
|
|
1481
|
-
|
|
1482
|
-
|
|
1483
|
-
|
|
1484
|
-
y_flat * 3, 9
|
|
1485
|
-
) # Repeat each y coord 9 times and scale by 3
|
|
1486
|
-
x_base = jnp.repeat(
|
|
1487
|
-
x_flat * 3, 9
|
|
1488
|
-
) # Repeat each x coord 9 times and scale by 3
|
|
1489
|
-
y_offsets = jnp.tile(
|
|
1490
|
-
offsets[:, 0], num_aperture_cells
|
|
1491
|
-
) # Tile all 9 offsets
|
|
1492
|
-
x_offsets = jnp.tile(
|
|
1493
|
-
offsets[:, 1], num_aperture_cells
|
|
1494
|
-
) # Tile all 9 offsets
|
|
1495
|
-
y_expanded = y_base + y_offsets
|
|
1496
|
-
x_expanded = x_base + x_offsets
|
|
1497
|
-
|
|
1498
|
-
# Get original colors and tint them
|
|
1499
|
-
original_colors = img[y_expanded, x_expanded]
|
|
1500
|
-
tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
|
|
1501
|
-
img = img.at[y_expanded, x_expanded].set(tinted_colors)
|
|
1502
|
-
|
|
1503
|
-
# Agent color - set all 9 cells of the agent's 3x3 block
|
|
1504
|
-
agent_y, agent_x = state.pos[1], state.pos[0]
|
|
1505
|
-
agent_offsets = jnp.array(
|
|
1506
|
-
[[dy, dx] for dy in range(3) for dx in range(3)]
|
|
1596
|
+
# Repeat base colors and rewards to 3x3
|
|
1597
|
+
base_img_x3 = jnp.repeat(jnp.repeat(img, 3, axis=0), 3, axis=1)
|
|
1598
|
+
reward_colors_x3 = jnp.repeat(
|
|
1599
|
+
jnp.repeat(reward_colors, 3, axis=0), 3, axis=1
|
|
1507
1600
|
)
|
|
1508
|
-
|
|
1509
|
-
|
|
1510
|
-
img =
|
|
1511
|
-
|
|
1601
|
+
|
|
1602
|
+
# Composite base and reward colors
|
|
1603
|
+
img = jnp.where(
|
|
1604
|
+
grid_reward_mask[..., None], reward_colors_x3, base_img_x3
|
|
1512
1605
|
)
|
|
1513
1606
|
|
|
1514
|
-
#
|
|
1515
|
-
|
|
1516
|
-
|
|
1517
|
-
|
|
1518
|
-
|
|
1607
|
+
# Tint the aperture region at 3x scale
|
|
1608
|
+
aperture_mask = jnp.zeros((self.size[1], self.size[0]), dtype=bool)
|
|
1609
|
+
aperture_mask = aperture_mask.at[y_coords_adj, x_coords_adj].set(True)
|
|
1610
|
+
aperture_mask_x3 = jnp.repeat(
|
|
1611
|
+
jnp.repeat(aperture_mask, 3, axis=0), 3, axis=1
|
|
1519
1612
|
)
|
|
1613
|
+
|
|
1614
|
+
tinted_img = (
|
|
1615
|
+
(1.0 - alpha) * img.astype(jnp.float32)
|
|
1616
|
+
+ alpha * self.agent_color_jax.astype(jnp.float32)
|
|
1617
|
+
).astype(jnp.uint8)
|
|
1618
|
+
img = jnp.where(aperture_mask_x3[..., None], tinted_img, img)
|
|
1619
|
+
|
|
1620
|
+
# Set agent center block
|
|
1621
|
+
agent_mask = jnp.zeros((self.size[1], self.size[0]), dtype=bool)
|
|
1622
|
+
agent_mask = agent_mask.at[state.pos[1], state.pos[0]].set(True)
|
|
1623
|
+
agent_mask_x3 = jnp.repeat(jnp.repeat(agent_mask, 3, axis=0), 3, axis=1)
|
|
1624
|
+
img = jnp.where(agent_mask_x3[..., None], self.agent_color_jax, img)
|
|
1625
|
+
|
|
1626
|
+
# Final scale by 8 to get 24x
|
|
1627
|
+
img = jnp.repeat(jnp.repeat(img, 8, axis=0), 8, axis=1)
|
|
1520
1628
|
else:
|
|
1521
1629
|
# Standard rendering without reward visualization
|
|
1522
|
-
|
|
1523
|
-
|
|
1524
|
-
tint_mask = jnp.zeros((self.size[1], self.size[0]), dtype=int)
|
|
1525
|
-
tint_mask = tint_mask.at[y_coords_adj, x_coords_adj].set(1)
|
|
1526
|
-
# Apply tint to masked positions
|
|
1527
|
-
original_colors = img
|
|
1528
|
-
tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
|
|
1529
|
-
img = jnp.where(tint_mask[..., None], tinted_colors, img)
|
|
1530
|
-
else:
|
|
1531
|
-
original_colors = img[y_coords_adj, x_coords_adj]
|
|
1532
|
-
tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
|
|
1533
|
-
img = img.at[y_coords_adj, x_coords_adj].set(tinted_colors)
|
|
1630
|
+
aperture_mask = jnp.zeros((self.size[1], self.size[0]), dtype=bool)
|
|
1631
|
+
aperture_mask = aperture_mask.at[y_coords_adj, x_coords_adj].set(True)
|
|
1534
1632
|
|
|
1535
|
-
|
|
1536
|
-
|
|
1633
|
+
tinted_img = (
|
|
1634
|
+
(1.0 - alpha) * img.astype(jnp.float32)
|
|
1635
|
+
+ alpha * self.agent_color_jax.astype(jnp.float32)
|
|
1636
|
+
).astype(jnp.uint8)
|
|
1637
|
+
img = jnp.where(aperture_mask[..., None], tinted_img, img)
|
|
1537
1638
|
|
|
1538
|
-
|
|
1539
|
-
|
|
1540
|
-
|
|
1541
|
-
|
|
1542
|
-
)
|
|
1639
|
+
# Set agent
|
|
1640
|
+
img = img.at[state.pos[1], state.pos[0]].set(self.agent_color_jax)
|
|
1641
|
+
# Scale by 24
|
|
1642
|
+
img = jnp.repeat(jnp.repeat(img, 24, axis=0), 24, axis=1)
|
|
1543
1643
|
|
|
1544
1644
|
if is_true_mode:
|
|
1545
|
-
# Apply true object borders
|
|
1546
|
-
render_grid = state.object_state.object_id
|
|
1645
|
+
# Apply true object borders
|
|
1547
1646
|
img = apply_true_borders(
|
|
1548
|
-
img,
|
|
1647
|
+
img, state.object_state.object_id, self.size, len(self.object_ids)
|
|
1549
1648
|
)
|
|
1550
1649
|
|
|
1551
|
-
# Add grid lines
|
|
1552
|
-
|
|
1553
|
-
|
|
1554
|
-
|
|
1555
|
-
|
|
1556
|
-
|
|
1650
|
+
# Add grid lines using masking instead of slice-setting
|
|
1651
|
+
row_grid = (jnp.arange(self.size[1] * 24) % 24) == 0
|
|
1652
|
+
col_grid = (jnp.arange(self.size[0] * 24) % 24) == 0
|
|
1653
|
+
# skip first rows/cols as they are borders or managed by caller
|
|
1654
|
+
row_grid = row_grid.at[0].set(False)
|
|
1655
|
+
col_grid = col_grid.at[0].set(False)
|
|
1656
|
+
grid_mask = row_grid[:, None] | col_grid[None, :]
|
|
1657
|
+
img = jnp.where(grid_mask[..., None], self.grid_color_jax, img)
|
|
1557
1658
|
|
|
1558
1659
|
elif is_aperture_mode:
|
|
1559
1660
|
obs_grid = state.object_state.object_id
|
|
@@ -1600,11 +1701,16 @@ class ForagaxEnv(environment.Environment):
|
|
|
1600
1701
|
self._compute_aperture_coordinates(state.pos)
|
|
1601
1702
|
)
|
|
1602
1703
|
|
|
1603
|
-
# Get reward grid for
|
|
1604
|
-
|
|
1605
|
-
|
|
1606
|
-
|
|
1607
|
-
|
|
1704
|
+
# Get reward grid only for aperture region
|
|
1705
|
+
aperture_object_ids = state.object_state.object_id[
|
|
1706
|
+
y_coords_adj, x_coords_adj
|
|
1707
|
+
]
|
|
1708
|
+
aperture_params = state.object_state.state_params[
|
|
1709
|
+
y_coords_adj, x_coords_adj
|
|
1710
|
+
]
|
|
1711
|
+
aperture_rewards = self._compute_reward_grid(
|
|
1712
|
+
state, aperture_object_ids, aperture_params
|
|
1713
|
+
)
|
|
1608
1714
|
|
|
1609
1715
|
# Convert rewards to colors
|
|
1610
1716
|
reward_colors = self._reward_to_color(aperture_rewards)
|