continual-foragax 0.30.1__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.30.1.dist-info → continual_foragax-0.32.0.dist-info}/METADATA +1 -1
- {continual_foragax-0.30.1.dist-info → continual_foragax-0.32.0.dist-info}/RECORD +8 -8
- foragax/env.py +783 -132
- foragax/objects.py +452 -5
- foragax/registry.py +83 -0
- {continual_foragax-0.30.1.dist-info → continual_foragax-0.32.0.dist-info}/WHEEL +0 -0
- {continual_foragax-0.30.1.dist-info → continual_foragax-0.32.0.dist-info}/entry_points.txt +0 -0
- {continual_foragax-0.30.1.dist-info → continual_foragax-0.32.0.dist-info}/top_level.txt +0 -0
foragax/env.py
CHANGED
|
@@ -19,6 +19,7 @@ from foragax.objects import (
|
|
|
19
19
|
EMPTY,
|
|
20
20
|
PADDING,
|
|
21
21
|
BaseForagaxObject,
|
|
22
|
+
FourierObject,
|
|
22
23
|
WeatherObject,
|
|
23
24
|
)
|
|
24
25
|
from foragax.rendering import apply_true_borders
|
|
@@ -42,6 +43,48 @@ DIRECTIONS = jnp.array(
|
|
|
42
43
|
)
|
|
43
44
|
|
|
44
45
|
|
|
46
|
+
@struct.dataclass
|
|
47
|
+
class ObjectState:
|
|
48
|
+
"""Per-cell object state information.
|
|
49
|
+
|
|
50
|
+
This struct encapsulates all state information about objects in the grid:
|
|
51
|
+
- object_id: (H, W) The object type (0 for empty, positive for object type)
|
|
52
|
+
- respawn_timer: (H, W) Countdown timer for respawning (0 = no timer, positive = countdown remaining)
|
|
53
|
+
- respawn_object_id: (H, W) What object type will spawn when timer reaches 0
|
|
54
|
+
- spawn_time: (H, W) When each object was spawned (for expiry tracking)
|
|
55
|
+
- color: (H, W, 3) RGB color for each object instance (for dynamic biomes)
|
|
56
|
+
- generation: (H, W) Which biome generation each object belongs to
|
|
57
|
+
- state_params: (H, W, N) Per-instance parameters (e.g., Fourier coefficients)
|
|
58
|
+
- biome_id: (H, W) Which biome each cell belongs to
|
|
59
|
+
"""
|
|
60
|
+
|
|
61
|
+
object_id: jax.Array # (H, W) - Object type ID (0 = empty, >0 = object type)
|
|
62
|
+
respawn_timer: (
|
|
63
|
+
jax.Array
|
|
64
|
+
) # (H, W) - Respawn countdown (0 = no timer, >0 = countdown)
|
|
65
|
+
respawn_object_id: jax.Array # (H, W) - Object type to spawn when timer reaches 0
|
|
66
|
+
spawn_time: jax.Array # (H, W) - Timestep when object spawned
|
|
67
|
+
color: jax.Array # (H, W, 3) - RGB color per instance
|
|
68
|
+
generation: jax.Array # (H, W) - Biome generation number
|
|
69
|
+
state_params: jax.Array # (H, W, N) - Per-instance parameters
|
|
70
|
+
biome_id: jax.Array # (H, W) - Biome assignment for each cell
|
|
71
|
+
|
|
72
|
+
@classmethod
|
|
73
|
+
def create_empty(cls, size: Tuple[int, int], num_params: int) -> "ObjectState":
|
|
74
|
+
"""Create an empty ObjectState for the given grid size."""
|
|
75
|
+
h, w = size[1], size[0]
|
|
76
|
+
return cls(
|
|
77
|
+
object_id=jnp.zeros((h, w), dtype=int),
|
|
78
|
+
respawn_timer=jnp.zeros((h, w), dtype=int),
|
|
79
|
+
respawn_object_id=jnp.zeros((h, w), dtype=int),
|
|
80
|
+
spawn_time=jnp.zeros((h, w), dtype=int),
|
|
81
|
+
color=jnp.full((h, w, 3), 255, dtype=jnp.uint8),
|
|
82
|
+
generation=jnp.zeros((h, w), dtype=int),
|
|
83
|
+
state_params=jnp.zeros((h, w, num_params), dtype=jnp.float32),
|
|
84
|
+
biome_id=jnp.full((h, w), -1, dtype=int),
|
|
85
|
+
)
|
|
86
|
+
|
|
87
|
+
|
|
45
88
|
@dataclass
|
|
46
89
|
class Biome:
|
|
47
90
|
# Object generation frequencies for this biome
|
|
@@ -55,13 +98,22 @@ class EnvParams(environment.EnvParams):
|
|
|
55
98
|
max_steps_in_episode: Union[int, None]
|
|
56
99
|
|
|
57
100
|
|
|
101
|
+
@struct.dataclass
|
|
102
|
+
class BiomeState:
|
|
103
|
+
"""Biome-level tracking state (num_biomes,)."""
|
|
104
|
+
|
|
105
|
+
consumption_count: jax.Array # objects consumed per biome
|
|
106
|
+
total_objects: jax.Array # total objects spawned per biome
|
|
107
|
+
generation: jax.Array # current generation per biome
|
|
108
|
+
|
|
109
|
+
|
|
58
110
|
@struct.dataclass
|
|
59
111
|
class EnvState(environment.EnvState):
|
|
60
112
|
pos: jax.Array
|
|
61
|
-
object_grid: jax.Array
|
|
62
|
-
biome_grid: jax.Array
|
|
63
113
|
time: int
|
|
64
114
|
digestion_buffer: jax.Array
|
|
115
|
+
object_state: ObjectState
|
|
116
|
+
biome_state: BiomeState
|
|
65
117
|
|
|
66
118
|
|
|
67
119
|
class ForagaxEnv(environment.Environment):
|
|
@@ -78,6 +130,8 @@ class ForagaxEnv(environment.Environment):
|
|
|
78
130
|
deterministic_spawn: bool = False,
|
|
79
131
|
teleport_interval: Optional[int] = None,
|
|
80
132
|
observation_type: str = "object",
|
|
133
|
+
dynamic_biomes: bool = False,
|
|
134
|
+
biome_consumption_threshold: float = 0.9,
|
|
81
135
|
):
|
|
82
136
|
super().__init__()
|
|
83
137
|
self._name = name
|
|
@@ -99,11 +153,24 @@ class ForagaxEnv(environment.Environment):
|
|
|
99
153
|
self.nowrap = nowrap
|
|
100
154
|
self.deterministic_spawn = deterministic_spawn
|
|
101
155
|
self.teleport_interval = teleport_interval
|
|
156
|
+
self.dynamic_biomes = dynamic_biomes
|
|
157
|
+
self.biome_consumption_threshold = biome_consumption_threshold
|
|
158
|
+
|
|
102
159
|
objects = (EMPTY,) + objects
|
|
103
160
|
if self.nowrap and not self.full_world:
|
|
104
161
|
objects = objects + (PADDING,)
|
|
105
162
|
self.objects = objects
|
|
106
163
|
|
|
164
|
+
# Infer num_fourier_terms from objects
|
|
165
|
+
self.num_fourier_terms = max(
|
|
166
|
+
(
|
|
167
|
+
obj.num_fourier_terms
|
|
168
|
+
for obj in self.objects
|
|
169
|
+
if isinstance(obj, FourierObject)
|
|
170
|
+
),
|
|
171
|
+
default=0,
|
|
172
|
+
)
|
|
173
|
+
|
|
107
174
|
# JIT-compatible versions of object and biome properties
|
|
108
175
|
self.object_ids = jnp.arange(len(objects))
|
|
109
176
|
self.object_blocking = jnp.array([o.blocking for o in objects])
|
|
@@ -114,6 +181,15 @@ class ForagaxEnv(environment.Environment):
|
|
|
114
181
|
self.reward_fns = [o.reward for o in objects]
|
|
115
182
|
self.regen_delay_fns = [o.regen_delay for o in objects]
|
|
116
183
|
self.reward_delay_fns = [o.reward_delay for o in objects]
|
|
184
|
+
self.expiry_regen_delay_fns = [o.expiry_regen_delay for o in objects]
|
|
185
|
+
|
|
186
|
+
# Expiry times per object (None becomes -1 for no expiry)
|
|
187
|
+
self.object_expiry_time = jnp.array(
|
|
188
|
+
[o.expiry_time if o.expiry_time is not None else -1 for o in objects]
|
|
189
|
+
)
|
|
190
|
+
|
|
191
|
+
# Check if any objects can expire
|
|
192
|
+
self.has_expiring_objects = jnp.any(self.object_expiry_time >= 0)
|
|
117
193
|
|
|
118
194
|
# Compute reward steps per object (using max_reward_delay attribute)
|
|
119
195
|
object_max_reward_delay = jnp.array([o.max_reward_delay for o in objects])
|
|
@@ -131,6 +207,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
131
207
|
[b.stop if b.stop is not None else (-1, -1) for b in biomes]
|
|
132
208
|
)
|
|
133
209
|
self.biome_sizes = np.prod(self.biome_stops - self.biome_starts, axis=1)
|
|
210
|
+
self.biome_sizes_jax = jnp.array(self.biome_sizes) # JAX version for indexing
|
|
134
211
|
self.biome_starts_jax = jnp.array(self.biome_starts)
|
|
135
212
|
self.biome_stops_jax = jnp.array(self.biome_stops)
|
|
136
213
|
biome_centers = []
|
|
@@ -142,6 +219,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
142
219
|
biome_centers.append((center_x, center_y))
|
|
143
220
|
self.biome_centers_jax = jnp.array(biome_centers)
|
|
144
221
|
self.biome_masks = []
|
|
222
|
+
biome_masks_array = []
|
|
145
223
|
for i in range(self.biome_object_frequencies.shape[0]):
|
|
146
224
|
# Create mask for the biome
|
|
147
225
|
start = jax.lax.select(
|
|
@@ -163,6 +241,10 @@ class ForagaxEnv(environment.Environment):
|
|
|
163
241
|
& (cols < stop[0])
|
|
164
242
|
)
|
|
165
243
|
self.biome_masks.append(mask)
|
|
244
|
+
biome_masks_array.append(mask)
|
|
245
|
+
|
|
246
|
+
# Convert to JAX array for indexing in JIT-compiled code
|
|
247
|
+
self.biome_masks_array = jnp.array(biome_masks_array)
|
|
166
248
|
|
|
167
249
|
# Compute unique colors and mapping for partial observability (for 'color' observation_type)
|
|
168
250
|
# Exclude EMPTY (index 0) from color channels
|
|
@@ -193,6 +275,90 @@ class ForagaxEnv(environment.Environment):
|
|
|
193
275
|
max_steps_in_episode=None,
|
|
194
276
|
)
|
|
195
277
|
|
|
278
|
+
def _place_timer(
|
|
279
|
+
self,
|
|
280
|
+
object_state: ObjectState,
|
|
281
|
+
y: int,
|
|
282
|
+
x: int,
|
|
283
|
+
object_type: int,
|
|
284
|
+
timer_val: int,
|
|
285
|
+
random_respawn: bool,
|
|
286
|
+
rand_key: jax.Array,
|
|
287
|
+
) -> ObjectState:
|
|
288
|
+
"""Place a timer at position or randomly within the same biome.
|
|
289
|
+
|
|
290
|
+
Args:
|
|
291
|
+
object_state: Current object state
|
|
292
|
+
y, x: Original position
|
|
293
|
+
object_type: The object type ID that will respawn (0 for permanent removal)
|
|
294
|
+
timer_val: Timer countdown value (0 for permanent removal, positive for countdown)
|
|
295
|
+
random_respawn: If True, place at random location in same biome
|
|
296
|
+
rand_key: Random key for random placement
|
|
297
|
+
|
|
298
|
+
Returns:
|
|
299
|
+
Updated object_state with timer placed
|
|
300
|
+
"""
|
|
301
|
+
|
|
302
|
+
# Handle permanent removal (timer_val == 0)
|
|
303
|
+
def place_empty():
|
|
304
|
+
return object_state.replace(
|
|
305
|
+
object_id=object_state.object_id.at[y, x].set(0),
|
|
306
|
+
respawn_timer=object_state.respawn_timer.at[y, x].set(0),
|
|
307
|
+
respawn_object_id=object_state.respawn_object_id.at[y, x].set(0),
|
|
308
|
+
)
|
|
309
|
+
|
|
310
|
+
# Handle timer placement
|
|
311
|
+
def place_timer():
|
|
312
|
+
# Non-random: place at original position
|
|
313
|
+
def place_at_position():
|
|
314
|
+
return object_state.replace(
|
|
315
|
+
object_id=object_state.object_id.at[y, x].set(0),
|
|
316
|
+
respawn_timer=object_state.respawn_timer.at[y, x].set(timer_val),
|
|
317
|
+
respawn_object_id=object_state.respawn_object_id.at[y, x].set(
|
|
318
|
+
object_type
|
|
319
|
+
),
|
|
320
|
+
)
|
|
321
|
+
|
|
322
|
+
# Random: place at random location in same biome
|
|
323
|
+
def place_randomly():
|
|
324
|
+
# Clear the collected object's position
|
|
325
|
+
new_object_id = object_state.object_id.at[y, x].set(0)
|
|
326
|
+
new_respawn_timer = object_state.respawn_timer.at[y, x].set(0)
|
|
327
|
+
new_respawn_object_id = object_state.respawn_object_id.at[y, x].set(0)
|
|
328
|
+
|
|
329
|
+
# Find valid spawn locations in the same biome
|
|
330
|
+
biome_id = object_state.biome_id[y, x]
|
|
331
|
+
biome_mask = object_state.biome_id == biome_id
|
|
332
|
+
empty_mask = new_object_id == 0
|
|
333
|
+
no_timer_mask = new_respawn_timer == 0
|
|
334
|
+
valid_spawn_mask = biome_mask & empty_mask & no_timer_mask
|
|
335
|
+
num_valid_spawns = jnp.sum(valid_spawn_mask)
|
|
336
|
+
|
|
337
|
+
y_indices, x_indices = jnp.nonzero(
|
|
338
|
+
valid_spawn_mask, size=self.size[0] * self.size[1], fill_value=-1
|
|
339
|
+
)
|
|
340
|
+
valid_spawn_indices = jnp.stack([y_indices, x_indices], axis=1)
|
|
341
|
+
random_idx = jax.random.randint(rand_key, (), 0, num_valid_spawns)
|
|
342
|
+
new_spawn_pos = valid_spawn_indices[random_idx]
|
|
343
|
+
|
|
344
|
+
# Place timer at the new random position
|
|
345
|
+
new_respawn_timer = new_respawn_timer.at[
|
|
346
|
+
new_spawn_pos[0], new_spawn_pos[1]
|
|
347
|
+
].set(timer_val)
|
|
348
|
+
new_respawn_object_id = new_respawn_object_id.at[
|
|
349
|
+
new_spawn_pos[0], new_spawn_pos[1]
|
|
350
|
+
].set(object_type)
|
|
351
|
+
|
|
352
|
+
return object_state.replace(
|
|
353
|
+
object_id=new_object_id,
|
|
354
|
+
respawn_timer=new_respawn_timer,
|
|
355
|
+
respawn_object_id=new_respawn_object_id,
|
|
356
|
+
)
|
|
357
|
+
|
|
358
|
+
return jax.lax.cond(random_respawn, place_randomly, place_at_position)
|
|
359
|
+
|
|
360
|
+
return jax.lax.cond(timer_val == 0, place_empty, place_timer)
|
|
361
|
+
|
|
196
362
|
def step_env(
|
|
197
363
|
self,
|
|
198
364
|
key: jax.Array,
|
|
@@ -201,9 +367,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
201
367
|
params: EnvParams,
|
|
202
368
|
) -> Tuple[jax.Array, EnvState, jax.Array, jax.Array, Dict[Any, Any]]:
|
|
203
369
|
"""Perform single timestep state transition."""
|
|
204
|
-
|
|
205
|
-
# Decode the object grid: positive values are objects, negative are timers (treat as empty)
|
|
206
|
-
current_objects = jnp.maximum(0, state.object_grid)
|
|
370
|
+
current_objects = state.object_state.object_id
|
|
207
371
|
|
|
208
372
|
# 1. UPDATE AGENT POSITION
|
|
209
373
|
direction = DIRECTIONS[action]
|
|
@@ -246,9 +410,12 @@ class ForagaxEnv(environment.Environment):
|
|
|
246
410
|
# Handle digestion: add reward to buffer if collected
|
|
247
411
|
digestion_buffer = state.digestion_buffer
|
|
248
412
|
key, reward_subkey = jax.random.split(key)
|
|
413
|
+
|
|
414
|
+
object_params = state.object_state.state_params[pos[1], pos[0]]
|
|
249
415
|
object_reward = jax.lax.switch(
|
|
250
|
-
obj_at_pos, self.reward_fns, state.time, reward_subkey
|
|
416
|
+
obj_at_pos, self.reward_fns, state.time, reward_subkey, object_params
|
|
251
417
|
)
|
|
418
|
+
|
|
252
419
|
key, digestion_subkey = jax.random.split(key)
|
|
253
420
|
reward_delay = jax.lax.switch(
|
|
254
421
|
obj_at_pos, self.reward_delay_fns, state.time, digestion_subkey
|
|
@@ -271,61 +438,120 @@ class ForagaxEnv(environment.Environment):
|
|
|
271
438
|
# 3. HANDLE OBJECT COLLECTION AND RESPAWNING
|
|
272
439
|
key, regen_subkey, rand_key = jax.random.split(key, 3)
|
|
273
440
|
|
|
274
|
-
# Decrement timers
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
441
|
+
# Decrement respawn timers
|
|
442
|
+
has_timer = state.object_state.respawn_timer > 0
|
|
443
|
+
new_respawn_timer = jnp.where(
|
|
444
|
+
has_timer,
|
|
445
|
+
state.object_state.respawn_timer - 1,
|
|
446
|
+
state.object_state.respawn_timer,
|
|
447
|
+
)
|
|
448
|
+
|
|
449
|
+
# Track which cells have timers that just reached 0
|
|
450
|
+
just_respawned = has_timer & (new_respawn_timer == 0)
|
|
451
|
+
|
|
452
|
+
# Respawn objects where timer reached 0
|
|
453
|
+
new_object_id = jnp.where(
|
|
454
|
+
just_respawned,
|
|
455
|
+
state.object_state.respawn_object_id,
|
|
456
|
+
state.object_state.object_id,
|
|
457
|
+
)
|
|
458
|
+
|
|
459
|
+
# Clear respawn_object_id for cells that just respawned
|
|
460
|
+
new_respawn_object_id = jnp.where(
|
|
461
|
+
just_respawned,
|
|
462
|
+
0,
|
|
463
|
+
state.object_state.respawn_object_id,
|
|
464
|
+
)
|
|
465
|
+
|
|
466
|
+
# Update spawn times for objects that just respawned
|
|
467
|
+
spawn_time = jnp.where(
|
|
468
|
+
just_respawned, state.time, state.object_state.spawn_time
|
|
278
469
|
)
|
|
279
470
|
|
|
280
471
|
# Collect object: set a timer
|
|
281
472
|
regen_delay = jax.lax.switch(
|
|
282
473
|
obj_at_pos, self.regen_delay_fns, state.time, regen_subkey
|
|
283
474
|
)
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
475
|
+
timer_countdown = jax.lax.cond(
|
|
476
|
+
regen_delay == jnp.iinfo(jnp.int32).max,
|
|
477
|
+
lambda: 0, # No timer (permanent removal)
|
|
478
|
+
lambda: regen_delay + 1, # Timer countdown
|
|
479
|
+
)
|
|
288
480
|
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
481
|
+
# If collected, replace object with timer; otherwise, keep it
|
|
482
|
+
val_at_pos = current_objects[pos[1], pos[0]]
|
|
483
|
+
# Use original should_collect for consumption tracking
|
|
484
|
+
should_collect_now = is_collectable & (val_at_pos > 0)
|
|
485
|
+
|
|
486
|
+
# Create updated object state with new respawn_timer, object_id, and spawn_time
|
|
487
|
+
object_state = state.object_state.replace(
|
|
488
|
+
object_id=new_object_id,
|
|
489
|
+
respawn_timer=new_respawn_timer,
|
|
490
|
+
respawn_object_id=new_respawn_object_id,
|
|
491
|
+
spawn_time=spawn_time,
|
|
492
|
+
)
|
|
292
493
|
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
494
|
+
# Place timer on collection
|
|
495
|
+
object_state = jax.lax.cond(
|
|
496
|
+
should_collect_now,
|
|
497
|
+
lambda: self._place_timer(
|
|
498
|
+
object_state,
|
|
499
|
+
pos[1],
|
|
500
|
+
pos[0],
|
|
501
|
+
obj_at_pos, # object type
|
|
502
|
+
timer_countdown, # timer value
|
|
503
|
+
self.object_random_respawn[obj_at_pos],
|
|
504
|
+
rand_key,
|
|
505
|
+
),
|
|
506
|
+
lambda: object_state,
|
|
507
|
+
)
|
|
298
508
|
|
|
299
|
-
|
|
509
|
+
# Clear color grid when object is collected
|
|
510
|
+
object_state = jax.lax.cond(
|
|
511
|
+
should_collect_now,
|
|
512
|
+
lambda: object_state.replace(
|
|
513
|
+
color=object_state.color.at[pos[1], pos[0]].set(
|
|
514
|
+
jnp.full((3,), 255, dtype=jnp.uint8)
|
|
515
|
+
)
|
|
516
|
+
),
|
|
517
|
+
lambda: object_state,
|
|
518
|
+
)
|
|
300
519
|
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
520
|
+
# 3.5. HANDLE OBJECT EXPIRY
|
|
521
|
+
# Only process expiry if there are objects that can expire
|
|
522
|
+
key, object_state = self.expire_objects(key, state, object_state)
|
|
523
|
+
|
|
524
|
+
# 3.6. HANDLE DYNAMIC BIOME CONSUMPTION AND RESPAWNING
|
|
525
|
+
if self.dynamic_biomes:
|
|
526
|
+
# Update consumption count if an object was collected
|
|
527
|
+
# Only count if the object belongs to the current generation of its biome
|
|
528
|
+
collected_biome_id = object_state.biome_id[pos[1], pos[0]]
|
|
529
|
+
object_gen_at_pos = object_state.generation[pos[1], pos[0]]
|
|
530
|
+
current_biome_gen = state.biome_state.generation[collected_biome_id]
|
|
531
|
+
is_current_generation = object_gen_at_pos == current_biome_gen
|
|
532
|
+
|
|
533
|
+
biome_consumption_count = state.biome_state.consumption_count
|
|
534
|
+
biome_consumption_count = jax.lax.cond(
|
|
535
|
+
should_collect & is_current_generation,
|
|
536
|
+
lambda: biome_consumption_count.at[collected_biome_id].add(1),
|
|
537
|
+
lambda: biome_consumption_count,
|
|
304
538
|
)
|
|
305
|
-
valid_spawn_indices = jnp.stack([y_indices, x_indices], axis=1)
|
|
306
|
-
|
|
307
|
-
# Select a random valid location
|
|
308
|
-
random_idx = jax.random.randint(rand_key, (), 0, num_valid_spawns)
|
|
309
|
-
new_spawn_pos = valid_spawn_indices[random_idx]
|
|
310
539
|
|
|
311
|
-
#
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
),
|
|
327
|
-
lambda: object_grid,
|
|
328
|
-
)
|
|
540
|
+
# Check each biome for threshold crossing and respawn if needed
|
|
541
|
+
key, respawn_key = jax.random.split(key)
|
|
542
|
+
biome_state = BiomeState(
|
|
543
|
+
consumption_count=biome_consumption_count,
|
|
544
|
+
total_objects=state.biome_state.total_objects,
|
|
545
|
+
generation=state.biome_state.generation,
|
|
546
|
+
)
|
|
547
|
+
object_state, biome_state, respawn_key = self._check_and_respawn_biomes(
|
|
548
|
+
object_state,
|
|
549
|
+
biome_state,
|
|
550
|
+
state.time,
|
|
551
|
+
respawn_key,
|
|
552
|
+
)
|
|
553
|
+
else:
|
|
554
|
+
biome_state = state.biome_state
|
|
329
555
|
|
|
330
556
|
info = {"discount": self.discount(state, params)}
|
|
331
557
|
temperatures = jnp.zeros(len(self.objects))
|
|
@@ -335,16 +561,16 @@ class ForagaxEnv(environment.Environment):
|
|
|
335
561
|
get_temperature(obj.rewards, state.time, obj.repeat)
|
|
336
562
|
)
|
|
337
563
|
info["temperatures"] = temperatures
|
|
338
|
-
info["biome_id"] =
|
|
564
|
+
info["biome_id"] = object_state.biome_id[pos[1], pos[0]]
|
|
339
565
|
info["object_collected_id"] = jax.lax.select(should_collect, obj_at_pos, -1)
|
|
340
566
|
|
|
341
567
|
# 4. UPDATE STATE
|
|
342
568
|
state = EnvState(
|
|
343
569
|
pos=pos,
|
|
344
|
-
object_grid=object_grid,
|
|
345
|
-
biome_grid=state.biome_grid,
|
|
346
570
|
time=state.time + 1,
|
|
347
571
|
digestion_buffer=digestion_buffer,
|
|
572
|
+
object_state=object_state,
|
|
573
|
+
biome_state=biome_state,
|
|
348
574
|
)
|
|
349
575
|
|
|
350
576
|
done = self.is_terminal(state, params)
|
|
@@ -356,56 +582,401 @@ class ForagaxEnv(environment.Environment):
|
|
|
356
582
|
info,
|
|
357
583
|
)
|
|
358
584
|
|
|
585
|
+
def expire_objects(
|
|
586
|
+
self, key, state, object_state: ObjectState
|
|
587
|
+
) -> Tuple[jax.Array, ObjectState]:
|
|
588
|
+
if self.has_expiring_objects:
|
|
589
|
+
# Check each cell for objects that have exceeded their expiry time
|
|
590
|
+
current_objects_for_expiry = object_state.object_id
|
|
591
|
+
|
|
592
|
+
# Calculate age of each object (current_time - spawn_time)
|
|
593
|
+
object_ages = state.time - object_state.spawn_time
|
|
594
|
+
|
|
595
|
+
# Get expiry time for each object type in the grid
|
|
596
|
+
expiry_times = self.object_expiry_time[current_objects_for_expiry]
|
|
597
|
+
|
|
598
|
+
# Check if object should expire (age >= expiry_time and expiry_time >= 0)
|
|
599
|
+
should_expire = (
|
|
600
|
+
(object_ages >= expiry_times)
|
|
601
|
+
& (expiry_times >= 0)
|
|
602
|
+
& (current_objects_for_expiry > 0)
|
|
603
|
+
)
|
|
604
|
+
|
|
605
|
+
# Count how many objects actually need to expire
|
|
606
|
+
num_expiring = jnp.sum(should_expire)
|
|
607
|
+
|
|
608
|
+
# Only process expiry if there are actually objects to expire
|
|
609
|
+
def process_expiries():
|
|
610
|
+
# Get positions of objects that should expire
|
|
611
|
+
# Use nonzero with fixed size to maintain JIT compatibility
|
|
612
|
+
max_objects = self.size[0] * self.size[1]
|
|
613
|
+
y_indices, x_indices = jnp.nonzero(
|
|
614
|
+
should_expire, size=max_objects, fill_value=-1
|
|
615
|
+
)
|
|
616
|
+
|
|
617
|
+
key_local, expiry_key = jax.random.split(key)
|
|
618
|
+
|
|
619
|
+
def process_one_expiry(carry, i):
|
|
620
|
+
obj_state = carry
|
|
621
|
+
y = y_indices[i]
|
|
622
|
+
x = x_indices[i]
|
|
623
|
+
|
|
624
|
+
# Skip if this is a padding index (from fill_value)
|
|
625
|
+
is_valid = (y >= 0) & (x >= 0)
|
|
626
|
+
|
|
627
|
+
def expire_one():
|
|
628
|
+
obj_id = current_objects_for_expiry[y, x]
|
|
629
|
+
exp_key = jax.random.fold_in(expiry_key, y * self.size[0] + x)
|
|
630
|
+
exp_delay = jax.lax.switch(
|
|
631
|
+
obj_id, self.expiry_regen_delay_fns, state.time, exp_key
|
|
632
|
+
)
|
|
633
|
+
timer_countdown = jax.lax.cond(
|
|
634
|
+
exp_delay == jnp.iinfo(jnp.int32).max,
|
|
635
|
+
lambda: 0, # No timer (permanent removal)
|
|
636
|
+
lambda: exp_delay + 1, # Timer countdown
|
|
637
|
+
)
|
|
638
|
+
|
|
639
|
+
# Use unified timer placement method
|
|
640
|
+
rand_key = jax.random.split(exp_key)[1]
|
|
641
|
+
new_obj_state = self._place_timer(
|
|
642
|
+
obj_state,
|
|
643
|
+
y,
|
|
644
|
+
x,
|
|
645
|
+
obj_id,
|
|
646
|
+
timer_countdown,
|
|
647
|
+
self.object_random_respawn[obj_id],
|
|
648
|
+
rand_key,
|
|
649
|
+
)
|
|
650
|
+
|
|
651
|
+
# Clear color grid when object expires
|
|
652
|
+
empty_color = jnp.full((3,), 255, dtype=jnp.uint8)
|
|
653
|
+
new_obj_state = new_obj_state.replace(
|
|
654
|
+
color=new_obj_state.color.at[y, x].set(empty_color)
|
|
655
|
+
)
|
|
656
|
+
|
|
657
|
+
return new_obj_state
|
|
658
|
+
|
|
659
|
+
def no_op():
|
|
660
|
+
return obj_state
|
|
661
|
+
|
|
662
|
+
return jax.lax.cond(is_valid, expire_one, no_op), None
|
|
663
|
+
|
|
664
|
+
new_object_state, _ = jax.lax.scan(
|
|
665
|
+
process_one_expiry,
|
|
666
|
+
object_state,
|
|
667
|
+
jnp.arange(max_objects),
|
|
668
|
+
)
|
|
669
|
+
return key_local, new_object_state
|
|
670
|
+
|
|
671
|
+
def no_expiries():
|
|
672
|
+
return key, object_state
|
|
673
|
+
|
|
674
|
+
key, object_state = jax.lax.cond(
|
|
675
|
+
num_expiring > 0,
|
|
676
|
+
process_expiries,
|
|
677
|
+
no_expiries,
|
|
678
|
+
)
|
|
679
|
+
|
|
680
|
+
return key, object_state
|
|
681
|
+
|
|
682
|
+
def _check_and_respawn_biomes(
|
|
683
|
+
self,
|
|
684
|
+
object_state: ObjectState,
|
|
685
|
+
biome_state: BiomeState,
|
|
686
|
+
current_time: int,
|
|
687
|
+
key: jax.Array,
|
|
688
|
+
) -> Tuple[ObjectState, BiomeState, jax.Array]:
|
|
689
|
+
"""Check all biomes for consumption threshold and respawn if needed."""
|
|
690
|
+
|
|
691
|
+
num_biomes = self.biome_object_frequencies.shape[0]
|
|
692
|
+
|
|
693
|
+
# Compute consumption rates for all biomes
|
|
694
|
+
consumption_rates = biome_state.consumption_count / jnp.maximum(
|
|
695
|
+
1.0, biome_state.total_objects.astype(float)
|
|
696
|
+
)
|
|
697
|
+
should_respawn = consumption_rates >= self.biome_consumption_threshold
|
|
698
|
+
|
|
699
|
+
# Split key for all biomes in parallel
|
|
700
|
+
key, subkey = jax.random.split(key)
|
|
701
|
+
biome_keys = jax.random.split(subkey, num_biomes)
|
|
702
|
+
|
|
703
|
+
# Compute all new spawns in parallel using vmap for random, switch for deterministic
|
|
704
|
+
if self.deterministic_spawn:
|
|
705
|
+
# Use switch to dispatch to concrete biome spawns for deterministic
|
|
706
|
+
def make_spawn_fn(biome_idx):
|
|
707
|
+
def spawn_fn(key):
|
|
708
|
+
return self._spawn_biome_objects(biome_idx, key, deterministic=True)
|
|
709
|
+
|
|
710
|
+
return spawn_fn
|
|
711
|
+
|
|
712
|
+
spawn_fns = [make_spawn_fn(idx) for idx in range(num_biomes)]
|
|
713
|
+
|
|
714
|
+
# Apply switch for each biome
|
|
715
|
+
all_new_objects_list = []
|
|
716
|
+
all_new_colors_list = []
|
|
717
|
+
all_new_params_list = []
|
|
718
|
+
for i in range(num_biomes):
|
|
719
|
+
obj, col, par = jax.lax.switch(i, spawn_fns, biome_keys[i])
|
|
720
|
+
all_new_objects_list.append(obj)
|
|
721
|
+
all_new_colors_list.append(col)
|
|
722
|
+
all_new_params_list.append(par)
|
|
723
|
+
|
|
724
|
+
all_new_objects = jnp.stack(all_new_objects_list)
|
|
725
|
+
all_new_colors = jnp.stack(all_new_colors_list)
|
|
726
|
+
all_new_params = jnp.stack(all_new_params_list)
|
|
727
|
+
else:
|
|
728
|
+
# Random spawn works with vmap
|
|
729
|
+
all_new_objects, all_new_colors, all_new_params = jax.vmap(
|
|
730
|
+
lambda i, k: self._spawn_biome_objects(i, k, deterministic=False)
|
|
731
|
+
)(jnp.arange(num_biomes), biome_keys)
|
|
732
|
+
|
|
733
|
+
# Initialize updated grids
|
|
734
|
+
new_obj_id = object_state.object_id
|
|
735
|
+
new_color = object_state.color
|
|
736
|
+
new_params = object_state.state_params
|
|
737
|
+
new_spawn = object_state.spawn_time
|
|
738
|
+
new_gen = object_state.generation
|
|
739
|
+
|
|
740
|
+
# Update biome state
|
|
741
|
+
new_consumption_count = jnp.where(
|
|
742
|
+
should_respawn, 0, biome_state.consumption_count
|
|
743
|
+
)
|
|
744
|
+
new_generation = biome_state.generation + should_respawn.astype(int)
|
|
745
|
+
|
|
746
|
+
# Compute new total objects for respawning biomes
|
|
747
|
+
def count_objects(i):
|
|
748
|
+
return jnp.sum((all_new_objects[i] > 0) & self.biome_masks_array[i])
|
|
749
|
+
|
|
750
|
+
new_object_counts = jax.vmap(count_objects)(jnp.arange(num_biomes))
|
|
751
|
+
new_total_objects = jnp.where(
|
|
752
|
+
should_respawn, new_object_counts, biome_state.total_objects
|
|
753
|
+
)
|
|
754
|
+
|
|
755
|
+
new_biome_state = BiomeState(
|
|
756
|
+
consumption_count=new_consumption_count,
|
|
757
|
+
total_objects=new_total_objects,
|
|
758
|
+
generation=new_generation,
|
|
759
|
+
)
|
|
760
|
+
|
|
761
|
+
# Update grids for respawning biomes
|
|
762
|
+
for i in range(num_biomes):
|
|
763
|
+
biome_mask = self.biome_masks_array[i]
|
|
764
|
+
new_gen_value = new_biome_state.generation[i]
|
|
765
|
+
|
|
766
|
+
# Only update where new spawn has objects and biome should respawn
|
|
767
|
+
is_new_object = (
|
|
768
|
+
(all_new_objects[i] > 0) & biome_mask & should_respawn[i][..., None]
|
|
769
|
+
)
|
|
770
|
+
|
|
771
|
+
new_obj_id = jnp.where(is_new_object, all_new_objects[i], new_obj_id)
|
|
772
|
+
new_color = jnp.where(
|
|
773
|
+
is_new_object[..., None], all_new_colors[i], new_color
|
|
774
|
+
)
|
|
775
|
+
new_params = jnp.where(
|
|
776
|
+
is_new_object[..., None], all_new_params[i], new_params
|
|
777
|
+
)
|
|
778
|
+
new_gen = jnp.where(is_new_object, new_gen_value, new_gen)
|
|
779
|
+
new_spawn = jnp.where(is_new_object, current_time, new_spawn)
|
|
780
|
+
|
|
781
|
+
# Clear timers in respawning biomes
|
|
782
|
+
new_respawn_timer = object_state.respawn_timer
|
|
783
|
+
new_respawn_object_id = object_state.respawn_object_id
|
|
784
|
+
for i in range(num_biomes):
|
|
785
|
+
biome_mask = self.biome_masks_array[i]
|
|
786
|
+
should_clear = biome_mask & should_respawn[i][..., None]
|
|
787
|
+
new_respawn_timer = jnp.where(should_clear, 0, new_respawn_timer)
|
|
788
|
+
new_respawn_object_id = jnp.where(should_clear, 0, new_respawn_object_id)
|
|
789
|
+
|
|
790
|
+
object_state = object_state.replace(
|
|
791
|
+
object_id=new_obj_id,
|
|
792
|
+
respawn_timer=new_respawn_timer,
|
|
793
|
+
respawn_object_id=new_respawn_object_id,
|
|
794
|
+
color=new_color,
|
|
795
|
+
state_params=new_params,
|
|
796
|
+
generation=new_gen,
|
|
797
|
+
spawn_time=new_spawn,
|
|
798
|
+
)
|
|
799
|
+
|
|
800
|
+
return object_state, new_biome_state, key
|
|
801
|
+
|
|
359
802
|
def reset_env(
|
|
360
803
|
self, key: jax.Array, params: EnvParams
|
|
361
804
|
) -> Tuple[jax.Array, EnvState]:
|
|
362
805
|
"""Reset environment state."""
|
|
363
|
-
|
|
364
|
-
|
|
806
|
+
num_object_params = 2 + 2 * self.num_fourier_terms
|
|
807
|
+
object_state = ObjectState.create_empty(self.size, num_object_params)
|
|
808
|
+
|
|
365
809
|
key, iter_key = jax.random.split(key)
|
|
810
|
+
|
|
811
|
+
# Spawn objects in each biome using unified method
|
|
366
812
|
for i in range(self.biome_object_frequencies.shape[0]):
|
|
367
813
|
iter_key, biome_key = jax.random.split(iter_key)
|
|
368
814
|
mask = self.biome_masks[i]
|
|
369
|
-
biome_grid = jnp.where(mask, i, biome_grid)
|
|
370
815
|
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
816
|
+
# Set biome_id
|
|
817
|
+
object_state = object_state.replace(
|
|
818
|
+
biome_id=jnp.where(mask, i, object_state.biome_id)
|
|
819
|
+
)
|
|
820
|
+
|
|
821
|
+
# Use unified spawn method
|
|
822
|
+
biome_objects, biome_colors, biome_object_params = (
|
|
823
|
+
self._spawn_biome_objects(i, biome_key, self.deterministic_spawn)
|
|
824
|
+
)
|
|
825
|
+
|
|
826
|
+
# Merge biome objects/colors/params into object_state
|
|
827
|
+
object_state = object_state.replace(
|
|
828
|
+
object_id=jnp.where(mask, biome_objects, object_state.object_id),
|
|
829
|
+
color=jnp.where(mask[..., None], biome_colors, object_state.color),
|
|
830
|
+
state_params=jnp.where(
|
|
831
|
+
mask[..., None], biome_object_params, object_state.state_params
|
|
832
|
+
),
|
|
833
|
+
)
|
|
377
834
|
|
|
378
835
|
# Place agent in the center of the world
|
|
379
836
|
agent_pos = jnp.array([self.size[0] // 2, self.size[1] // 2])
|
|
380
837
|
|
|
838
|
+
# Initialize biome consumption tracking
|
|
839
|
+
num_biomes = self.biome_object_frequencies.shape[0]
|
|
840
|
+
biome_consumption_count = jnp.zeros(num_biomes, dtype=int)
|
|
841
|
+
biome_total_objects = jnp.zeros(num_biomes, dtype=int)
|
|
842
|
+
|
|
843
|
+
# Count objects in each biome
|
|
844
|
+
for i in range(num_biomes):
|
|
845
|
+
mask = self.biome_masks[i]
|
|
846
|
+
# Count non-empty objects (object_id > 0)
|
|
847
|
+
total = jnp.sum((object_state.object_id > 0) & mask)
|
|
848
|
+
biome_total_objects = biome_total_objects.at[i].set(total)
|
|
849
|
+
|
|
850
|
+
biome_generation = jnp.zeros(num_biomes, dtype=int)
|
|
851
|
+
|
|
381
852
|
state = EnvState(
|
|
382
853
|
pos=agent_pos,
|
|
383
|
-
object_grid=object_grid,
|
|
384
|
-
biome_grid=biome_grid,
|
|
385
854
|
time=0,
|
|
386
855
|
digestion_buffer=jnp.zeros((self.max_reward_delay,)),
|
|
856
|
+
object_state=object_state,
|
|
857
|
+
biome_state=BiomeState(
|
|
858
|
+
consumption_count=biome_consumption_count,
|
|
859
|
+
total_objects=biome_total_objects,
|
|
860
|
+
generation=biome_generation,
|
|
861
|
+
),
|
|
387
862
|
)
|
|
388
863
|
|
|
389
864
|
return self.get_obs(state, params), state
|
|
390
865
|
|
|
391
|
-
def
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
866
|
+
def _spawn_biome_objects(
    self,
    biome_idx: int,
    key: jax.Array,
    deterministic: bool = False,
) -> Tuple[jax.Array, jax.Array, jax.Array]:
    """Spawn object ids, colors and state parameters for one biome.

    Args:
        biome_idx: Index into the per-biome tables on ``self``. Must be a
            concrete integer when ``deterministic`` is True, because the
            biome bounds and size are converted to static Python ints.
        key: PRNG key. It is split into independent keys for placement,
            colors and per-object-type parameters.
        deterministic: If True, place a fixed number of each object type
            (derived from the biome frequencies, then shuffled) inside the
            biome rectangle. If False, sample every cell independently
            from the biome frequencies.

    Returns:
        object_grid: (H, W) array of object IDs in ``[0, len(biome_freqs)]``.
            In deterministic mode cells outside the biome rectangle are 0.
            In random mode the whole grid is sampled, so callers must
            select the biome cells with the biome mask.
        color_grid: (H, W, 3) uint8 array of RGB colors; 255 wherever no
            biome object was colored.
        state_grid: (H, W, 2 + 2 * num_fourier_terms) float32 array of
            object state parameters; 0 wherever no parameters were written.
    """
    biome_freqs = self.biome_object_frequencies[biome_idx]
    biome_mask = self.biome_masks_array[biome_idx]
    grid_shape = (self.size[1], self.size[0])
    max_object_id = len(biome_freqs)

    # Keep the original 4-way split so the derived keys (and therefore the
    # sampled worlds) are unchanged; the first sub-key is intentionally unused.
    _, spawn_key, color_key, params_key = jax.random.split(key, 4)

    if deterministic:
        # Shapes and slice bounds must be static, so coerce everything to
        # Python ints up front. This is why biome_idx must be concrete here.
        biome_start = self.biome_starts[biome_idx]
        biome_stop = self.biome_stops[biome_idx]
        x0, y0 = int(biome_start[0]), int(biome_start[1])
        x1, y1 = int(biome_stop[0]), int(biome_stop[1])
        biome_size = int(self.biome_sizes[biome_idx])

        # Evenly spaced quantiles mapped through the reversed cumulative
        # frequencies give a fixed count per object type; quantiles beyond
        # the total frequency map to 0 (EMPTY).
        quantiles = jnp.linspace(0, 1, biome_size, endpoint=False)
        biome_objects_flat = max_object_id - jnp.searchsorted(
            jnp.cumsum(biome_freqs[::-1]), quantiles, side="right"
        )
        biome_objects_flat = jax.random.permutation(spawn_key, biome_objects_flat)
        biome_objects = biome_objects_flat.reshape(y1 - y0, x1 - x0)

        object_grid = jnp.zeros(grid_shape, dtype=int)
        object_grid = object_grid.at[y0:y1, x0:x1].set(biome_objects)
    else:
        # Random spawn: probabilistic placement (works with traced biome_idx).
        grid_rand = jax.random.uniform(spawn_key, grid_shape)
        empty_freq = 1.0 - jnp.sum(biome_freqs)
        all_freqs = jnp.concatenate([jnp.array([empty_freq]), biome_freqs])
        cumulative_freqs = jnp.cumsum(
            jnp.concatenate([jnp.array([0.0]), all_freqs])
        )
        object_grid = (
            jnp.searchsorted(cumulative_freqs, grid_rand, side="right") - 1
        )
        # Floating-point rounding can leave the last cumulative edge slightly
        # below the largest uniform sample, which would produce an id one past
        # the valid range. Clamp to the same range the deterministic branch
        # guarantees.
        object_grid = jnp.clip(object_grid, 0, max_object_id)

    # Colors: one random color per object type for this spawn, so objects of
    # the same type share a color. Index 0 (EMPTY) is skipped.
    color_grid = jnp.full(grid_shape + (3,), 255, dtype=jnp.uint8)
    num_object_types = len(self.objects)
    num_actual_objects = num_object_types - 1  # Exclude EMPTY

    if num_actual_objects > 0:
        biome_object_colors = jax.random.randint(
            color_key,
            (num_actual_objects, 3),
            minval=0,
            maxval=256,
            dtype=jnp.uint8,
        )
        for obj_idx in range(1, num_object_types):
            obj_mask = (object_grid == obj_idx) & biome_mask
            # Offset by 1 since EMPTY has no sampled color.
            obj_color = biome_object_colors[obj_idx - 1]
            color_grid = jnp.where(obj_mask[..., None], obj_color, color_grid)

    # Parameters: one parameter vector per object type, broadcast to every
    # cell of that type inside the biome.
    num_object_params = 2 + 2 * self.num_fourier_terms
    params_grid = jnp.zeros(grid_shape + (num_object_params,), dtype=jnp.float32)

    for obj_idx in range(num_object_types):
        # A key is consumed for every type (even ones without params) so the
        # key sequence does not depend on which types carry state.
        params_key, obj_key = jax.random.split(params_key)
        obj_params = self.objects[obj_idx].get_state(obj_key)
        num_params = len(obj_params)

        # Skip if no params (e.g., for EMPTY or default objects).
        if num_params == 0:
            continue

        # Normalize to the fixed per-cell width: zero-pad short vectors,
        # truncate long ones.
        if num_params < num_object_params:
            obj_params = jnp.pad(
                obj_params,
                (0, num_object_params - num_params),
                constant_values=0.0,
            )
        elif num_params > num_object_params:
            obj_params = obj_params[:num_object_params]

        obj_mask = (object_grid == obj_idx) & biome_mask
        params_grid = jnp.where(obj_mask[..., None], obj_params, params_grid)

    return object_grid, color_grid, params_grid
|
|
409
980
|
|
|
410
981
|
def is_terminal(self, state: EnvState, params: EnvParams) -> jax.Array:
|
|
411
982
|
"""Foragax is a continuing environment."""
|
|
@@ -428,21 +999,10 @@ class ForagaxEnv(environment.Environment):
|
|
|
428
999
|
|
|
429
1000
|
def state_space(self, params: EnvParams) -> spaces.Dict:
|
|
430
1001
|
"""State space of the environment."""
|
|
1002
|
+
num_object_params = 2 + 2 * self.num_fourier_terms
|
|
431
1003
|
return spaces.Dict(
|
|
432
1004
|
{
|
|
433
1005
|
"pos": spaces.Box(0, max(self.size), (2,), int),
|
|
434
|
-
"object_grid": spaces.Box(
|
|
435
|
-
-1000 * len(self.object_ids),
|
|
436
|
-
len(self.object_ids),
|
|
437
|
-
(self.size[1], self.size[0]),
|
|
438
|
-
int,
|
|
439
|
-
),
|
|
440
|
-
"biome_grid": spaces.Box(
|
|
441
|
-
0,
|
|
442
|
-
self.biome_object_frequencies.shape[0],
|
|
443
|
-
(self.size[1], self.size[0]),
|
|
444
|
-
int,
|
|
445
|
-
),
|
|
446
1006
|
"time": spaces.Discrete(params.max_steps_in_episode),
|
|
447
1007
|
"digestion_buffer": spaces.Box(
|
|
448
1008
|
-jnp.inf,
|
|
@@ -450,11 +1010,79 @@ class ForagaxEnv(environment.Environment):
|
|
|
450
1010
|
(self.max_reward_delay,),
|
|
451
1011
|
float,
|
|
452
1012
|
),
|
|
1013
|
+
"object_state": spaces.Dict(
|
|
1014
|
+
{
|
|
1015
|
+
"object_id": spaces.Box(
|
|
1016
|
+
-1000 * len(self.object_ids),
|
|
1017
|
+
len(self.object_ids),
|
|
1018
|
+
(self.size[1], self.size[0]),
|
|
1019
|
+
int,
|
|
1020
|
+
),
|
|
1021
|
+
"spawn_time": spaces.Box(
|
|
1022
|
+
0,
|
|
1023
|
+
jnp.inf,
|
|
1024
|
+
(self.size[1], self.size[0]),
|
|
1025
|
+
int,
|
|
1026
|
+
),
|
|
1027
|
+
"color": spaces.Box(
|
|
1028
|
+
0,
|
|
1029
|
+
255,
|
|
1030
|
+
(self.size[1], self.size[0], 3),
|
|
1031
|
+
int,
|
|
1032
|
+
),
|
|
1033
|
+
"generation": spaces.Box(
|
|
1034
|
+
0,
|
|
1035
|
+
jnp.inf,
|
|
1036
|
+
(self.size[1], self.size[0]),
|
|
1037
|
+
int,
|
|
1038
|
+
),
|
|
1039
|
+
"state_params": spaces.Box(
|
|
1040
|
+
-jnp.inf,
|
|
1041
|
+
jnp.inf,
|
|
1042
|
+
(self.size[1], self.size[0], num_object_params),
|
|
1043
|
+
float,
|
|
1044
|
+
),
|
|
1045
|
+
"biome_id": spaces.Box(
|
|
1046
|
+
-1,
|
|
1047
|
+
self.biome_object_frequencies.shape[0],
|
|
1048
|
+
(self.size[1], self.size[0]),
|
|
1049
|
+
int,
|
|
1050
|
+
),
|
|
1051
|
+
}
|
|
1052
|
+
),
|
|
1053
|
+
"biome_state": spaces.Dict(
|
|
1054
|
+
{
|
|
1055
|
+
"consumption_count": spaces.Box(
|
|
1056
|
+
0,
|
|
1057
|
+
jnp.inf,
|
|
1058
|
+
(self.biome_object_frequencies.shape[0],),
|
|
1059
|
+
int,
|
|
1060
|
+
),
|
|
1061
|
+
"total_objects": spaces.Box(
|
|
1062
|
+
0,
|
|
1063
|
+
jnp.inf,
|
|
1064
|
+
(self.biome_object_frequencies.shape[0],),
|
|
1065
|
+
int,
|
|
1066
|
+
),
|
|
1067
|
+
"generation": spaces.Box(
|
|
1068
|
+
0,
|
|
1069
|
+
jnp.inf,
|
|
1070
|
+
(self.biome_object_frequencies.shape[0],),
|
|
1071
|
+
int,
|
|
1072
|
+
),
|
|
1073
|
+
}
|
|
1074
|
+
),
|
|
453
1075
|
}
|
|
454
1076
|
)
|
|
455
1077
|
|
|
456
|
-
def
|
|
457
|
-
|
|
1078
|
+
def _compute_aperture_coordinates(
|
|
1079
|
+
self, pos: jax.Array
|
|
1080
|
+
) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
|
|
1081
|
+
"""Compute aperture coordinates for the given position.
|
|
1082
|
+
|
|
1083
|
+
Returns:
|
|
1084
|
+
(y_coords, x_coords, y_coords_clamped/mod, x_coords_clamped/mod)
|
|
1085
|
+
"""
|
|
458
1086
|
ap_h, ap_w = self.aperture_size
|
|
459
1087
|
start_y = pos[1] - ap_h // 2
|
|
460
1088
|
start_x = pos[0] - ap_w // 2
|
|
@@ -465,27 +1093,37 @@ class ForagaxEnv(environment.Environment):
|
|
|
465
1093
|
x_coords = start_x + x_offsets
|
|
466
1094
|
|
|
467
1095
|
if self.nowrap:
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
1096
|
+
y_coords_adj = jnp.clip(y_coords, 0, self.size[1] - 1)
|
|
1097
|
+
x_coords_adj = jnp.clip(x_coords, 0, self.size[0] - 1)
|
|
1098
|
+
else:
|
|
1099
|
+
y_coords_adj = jnp.mod(y_coords, self.size[1])
|
|
1100
|
+
x_coords_adj = jnp.mod(x_coords, self.size[0])
|
|
1101
|
+
|
|
1102
|
+
return y_coords, x_coords, y_coords_adj, x_coords_adj
|
|
1103
|
+
|
|
1104
|
+
def _get_aperture(self, object_id_grid: jax.Array, pos: jax.Array) -> jax.Array:
    """Return the window of object ids the agent sees around ``pos``.

    The window is gathered from ``object_id_grid`` using the adjusted
    coordinates produced by ``_compute_aperture_coordinates``. When
    ``self.nowrap`` is set, cells whose unadjusted coordinates fall outside
    the grid are replaced with the padding id ``self.object_ids[-1]``;
    otherwise the gathered values are returned unchanged.
    """
    rows, cols, rows_idx, cols_idx = self._compute_aperture_coordinates(pos)
    window = object_id_grid[rows_idx, cols_idx]

    if not self.nowrap:
        return window

    # A cell is real only if both of its unadjusted coordinates are on the
    # grid; everything else is shown as padding.
    grid_w, grid_h = self.size[0], self.size[1]
    inside = (rows >= 0) & (rows < grid_h) & (cols >= 0) & (cols < grid_w)
    return jnp.where(inside, window, self.object_ids[-1])
|
|
484
1123
|
|
|
485
1124
|
def get_obs(self, state: EnvState, params: EnvParams, key=None) -> jax.Array:
|
|
486
1125
|
"""Get observation based on observation_type and full_world."""
|
|
487
|
-
|
|
488
|
-
obs_grid = jnp.maximum(0, state.object_grid)
|
|
1126
|
+
obs_grid = state.object_state.object_id
|
|
489
1127
|
|
|
490
1128
|
if self.full_world:
|
|
491
1129
|
return self._get_world_obs(obs_grid, state)
|
|
@@ -587,48 +1225,43 @@ class ForagaxEnv(environment.Environment):
|
|
|
587
1225
|
|
|
588
1226
|
if is_world_mode:
|
|
589
1227
|
# Create an RGB image from the object grid
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
1228
|
+
# Use stateful object colors if dynamic_biomes is enabled, else use default colors
|
|
1229
|
+
if self.dynamic_biomes:
|
|
1230
|
+
# Use per-instance colors from state
|
|
1231
|
+
img = state.object_state.color.copy()
|
|
1232
|
+
else:
|
|
1233
|
+
# Use default object colors
|
|
1234
|
+
img = jnp.zeros((self.size[1], self.size[0], 3))
|
|
1235
|
+
render_grid = state.object_state.object_id
|
|
593
1236
|
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
1237
|
+
def update_image(i, img):
|
|
1238
|
+
color = self.object_colors[i]
|
|
1239
|
+
mask = render_grid == i
|
|
1240
|
+
img = jnp.where(mask[..., None], color, img)
|
|
1241
|
+
return img
|
|
599
1242
|
|
|
600
|
-
|
|
1243
|
+
img = jax.lax.fori_loop(0, len(self.object_ids), update_image, img)
|
|
601
1244
|
|
|
602
1245
|
# Tint the agent's aperture
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
1246
|
+
y_coords, x_coords, y_coords_adj, x_coords_adj = (
|
|
1247
|
+
self._compute_aperture_coordinates(state.pos)
|
|
1248
|
+
)
|
|
606
1249
|
|
|
607
1250
|
alpha = 0.2
|
|
608
1251
|
agent_color = jnp.array(AGENT.color)
|
|
609
1252
|
|
|
610
|
-
# Create indices for the aperture
|
|
611
|
-
y_offsets = jnp.arange(ap_h)
|
|
612
|
-
x_offsets = jnp.arange(ap_w)
|
|
613
|
-
y_coords_original = start_y + y_offsets[:, None]
|
|
614
|
-
x_coords_original = start_x + x_offsets
|
|
615
|
-
|
|
616
1253
|
if self.nowrap:
|
|
617
|
-
y_coords = jnp.clip(y_coords_original, 0, self.size[1] - 1)
|
|
618
|
-
x_coords = jnp.clip(x_coords_original, 0, self.size[0] - 1)
|
|
619
1254
|
# Create tint mask: any in-bounds original position maps to a cell makes it tinted
|
|
620
1255
|
tint_mask = jnp.zeros((self.size[1], self.size[0]), dtype=int)
|
|
621
|
-
tint_mask = tint_mask.at[
|
|
1256
|
+
tint_mask = tint_mask.at[y_coords_adj, x_coords_adj].set(1)
|
|
622
1257
|
# Apply tint to masked positions
|
|
623
1258
|
original_colors = img
|
|
624
1259
|
tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
|
|
625
1260
|
img = jnp.where(tint_mask[..., None], tinted_colors, img)
|
|
626
1261
|
else:
|
|
627
|
-
|
|
628
|
-
x_coords = jnp.mod(x_coords_original, self.size[0])
|
|
629
|
-
original_colors = img[y_coords, x_coords]
|
|
1262
|
+
original_colors = img[y_coords_adj, x_coords_adj]
|
|
630
1263
|
tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
|
|
631
|
-
img = img.at[
|
|
1264
|
+
img = img.at[y_coords_adj, x_coords_adj].set(tinted_colors)
|
|
632
1265
|
|
|
633
1266
|
# Agent color
|
|
634
1267
|
img = img.at[state.pos[1], state.pos[0]].set(jnp.array(AGENT.color))
|
|
@@ -641,6 +1274,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
641
1274
|
|
|
642
1275
|
if is_true_mode:
|
|
643
1276
|
# Apply true object borders by overlaying true colors on border pixels
|
|
1277
|
+
render_grid = state.object_state.object_id
|
|
644
1278
|
img = apply_true_borders(
|
|
645
1279
|
img, render_grid, self.size, len(self.object_ids)
|
|
646
1280
|
)
|
|
@@ -653,10 +1287,27 @@ class ForagaxEnv(environment.Environment):
|
|
|
653
1287
|
img = img.at[:, col_indices].set(grid_color)
|
|
654
1288
|
|
|
655
1289
|
elif is_aperture_mode:
|
|
656
|
-
obs_grid =
|
|
1290
|
+
obs_grid = state.object_state.object_id
|
|
657
1291
|
aperture = self._get_aperture(obs_grid, state.pos)
|
|
658
|
-
|
|
659
|
-
|
|
1292
|
+
|
|
1293
|
+
if self.dynamic_biomes:
|
|
1294
|
+
# Use per-instance colors from state - extract aperture view
|
|
1295
|
+
y_coords, x_coords, y_coords_adj, x_coords_adj = (
|
|
1296
|
+
self._compute_aperture_coordinates(state.pos)
|
|
1297
|
+
)
|
|
1298
|
+
img = state.object_state.color[y_coords_adj, x_coords_adj]
|
|
1299
|
+
|
|
1300
|
+
if self.nowrap:
|
|
1301
|
+
# For out-of-bounds, use padding object color
|
|
1302
|
+
y_out = (y_coords < 0) | (y_coords >= self.size[1])
|
|
1303
|
+
x_out = (x_coords < 0) | (x_coords >= self.size[0])
|
|
1304
|
+
out_of_bounds = y_out | x_out
|
|
1305
|
+
padding_color = jnp.array(self.objects[-1].color, dtype=jnp.float32)
|
|
1306
|
+
img = jnp.where(out_of_bounds[..., None], padding_color, img)
|
|
1307
|
+
else:
|
|
1308
|
+
# Use default object colors
|
|
1309
|
+
aperture_one_hot = jax.nn.one_hot(aperture, len(self.object_ids))
|
|
1310
|
+
img = jnp.tensordot(aperture_one_hot, self.object_colors, axes=1)
|
|
660
1311
|
|
|
661
1312
|
# Draw agent in the center
|
|
662
1313
|
center_y, center_x = self.aperture_size[1] // 2, self.aperture_size[0] // 2
|