continual-foragax 0.33.2__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.33.2.dist-info → continual_foragax-0.35.0.dist-info}/METADATA +1 -1
- {continual_foragax-0.33.2.dist-info → continual_foragax-0.35.0.dist-info}/RECORD +8 -8
- foragax/env.py +296 -36
- foragax/objects.py +9 -1
- foragax/registry.py +21 -0
- {continual_foragax-0.33.2.dist-info → continual_foragax-0.35.0.dist-info}/WHEEL +0 -0
- {continual_foragax-0.33.2.dist-info → continual_foragax-0.35.0.dist-info}/entry_points.txt +0 -0
- {continual_foragax-0.33.2.dist-info → continual_foragax-0.35.0.dist-info}/top_level.txt +0 -0
{continual_foragax-0.33.2.dist-info → continual_foragax-0.35.0.dist-info}/RECORD
CHANGED

```diff
@@ -1,8 +1,8 @@
 foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
-foragax/env.py,sha256=
-foragax/objects.py,sha256=
-foragax/registry.py,sha256=
+foragax/env.py,sha256=K3noPwdYmQlnXVjslqVzX_FIB-CnOh37mWFArQXnf_Y,66324
+foragax/objects.py,sha256=PPuLYjD7em7GL404eSpP6q8TxF8p7JtQ1kIwh7uD_tU,26860
+foragax/registry.py,sha256=Ph_Z3O5GpIjrgvbKL-8Iq-Kc6MqfZIsF9KDDDzm7N3o,18787
 foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
 foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
 foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
 foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
 foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
 foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
-continual_foragax-0.
-continual_foragax-0.
-continual_foragax-0.
-continual_foragax-0.
-continual_foragax-0.
+continual_foragax-0.35.0.dist-info/METADATA,sha256=pZ1uSNXsaYkaFaz7xvO9X7bee8Jr9RjT7nbwW3tx5ps,4713
+continual_foragax-0.35.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+continual_foragax-0.35.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
+continual_foragax-0.35.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
+continual_foragax-0.35.0.dist-info/RECORD,,
```
foragax/env.py
CHANGED

```diff
@@ -690,11 +690,16 @@ class ForagaxEnv(environment.Environment):
 
         num_biomes = self.biome_object_frequencies.shape[0]
 
-
-
-
-
-
+        if isinstance(self.biome_consumption_threshold, float):
+            # Compute consumption rates for all biomes
+            consumption_rates = biome_state.consumption_count / jnp.maximum(
+                1.0, biome_state.total_objects.astype(float)
+            )
+            should_respawn = consumption_rates >= self.biome_consumption_threshold
+        else:
+            should_respawn = (
+                biome_state.consumption_count >= self.biome_consumption_threshold
+            )
 
         # Split key for all biomes in parallel
         key, subkey = jax.random.split(key)
```
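The new branch makes `biome_consumption_threshold` polymorphic: a float is read as the fraction of a biome's objects that have been consumed, while an integer is an absolute consumption count. Below is a minimal sketch of that dispatch; the arrays are made-up stand-ins for the package's `biome_state` fields, not its actual pytree:

```python
import jax.numpy as jnp

# Hypothetical per-biome statistics: objects consumed vs. objects present.
consumption_count = jnp.array([45.0, 200.0])
total_objects = jnp.array([50.0, 400.0])

def should_respawn(threshold):
    """Mirror the env.py dispatch: float = fraction, int = absolute count."""
    if isinstance(threshold, float):
        # jnp.maximum guards against division by zero in empty biomes.
        rates = consumption_count / jnp.maximum(1.0, total_objects)
        return rates >= threshold
    return consumption_count >= threshold

print(should_respawn(0.9))  # fractional threshold: [ True False]
print(should_respawn(200))  # absolute threshold:   [False  True]
```

Since `isinstance` runs at trace time on a static attribute, the branch is resolved once when the step function is compiled and costs nothing inside `jax.jit`.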
```diff
@@ -1235,12 +1240,91 @@ class ForagaxEnv(environment.Environment):
 
         return spaces.Box(0, 1, obs_shape, float)
 
+    def _compute_reward_grid(self, state: EnvState) -> jax.Array:
+        """Compute rewards for all grid positions.
+
+        Returns:
+            Array of shape (H, W) with reward values for each cell
+        """
+        fixed_key = jax.random.key(0)  # Fixed key for deterministic reward computation
+
+        def compute_reward(obj_id, params):
+            return jax.lax.cond(
+                obj_id > 0,
+                lambda: jax.lax.switch(
+                    obj_id, self.reward_fns, state.time, fixed_key, params
+                ),
+                lambda: 0.0,
+            )
+
+        reward_grid = jax.vmap(jax.vmap(compute_reward))(
+            state.object_state.object_id, state.object_state.state_params
+        )
+        return reward_grid
+
+    def _reward_to_color(self, reward: jax.Array) -> jax.Array:
+        """Convert reward value to RGB color using diverging gradient.
+
+        Args:
+            reward: Reward value (typically -1 to +1)
+
+        Returns:
+            RGB color array with shape (..., 3) and dtype uint8
+        """
+        # Diverging gradient: +1 = green (0, 255, 0), 0 = white (255, 255, 255), -1 = magenta (255, 0, 255)
+        # Clamp reward to [-1, 1] range for color mapping
+        reward_clamped = jnp.clip(reward, -1.0, 1.0)
+
+        # For positive rewards: interpolate from white to green
+        # For negative rewards: interpolate from white to magenta
+        # At reward = 0: white (255, 255, 255)
+        # At reward = +1: green (0, 255, 0)
+        # At reward = -1: magenta (255, 0, 255)
+
+        red_component = jnp.where(
+            reward_clamped >= 0,
+            (1 - reward_clamped) * 255,  # Fade from white to green: 255 -> 0
+            255,  # Stay at 255 for all negative rewards
+        )
+
+        green_component = jnp.where(
+            reward_clamped >= 0,
+            255,  # Stay at 255 for all positive rewards
+            (1 + reward_clamped) * 255,  # Fade from white to magenta: 255 -> 0
+        )
+
+        blue_component = jnp.where(
+            reward_clamped >= 0,
+            (1 - reward_clamped) * 255,  # Fade from white to green: 255 -> 0
+            255,  # Stay at 255 for all negative rewards
+        )
+
+        return jnp.stack(
+            [red_component, green_component, blue_component], axis=-1
+        ).astype(jnp.uint8)
+
     @partial(jax.jit, static_argnames=("self", "render_mode"))
-    def render(
-
-
-
+    def render(
+        self,
+        state: EnvState,
+        params: EnvParams,
+        render_mode: str = "world",
+    ):
+        """Render the environment state.
+
+        Args:
+            state: Current environment state
+            params: Environment parameters
+            render_mode: One of "world", "world_true", "world_reward", "aperture", "aperture_true", "aperture_reward"
+        """
+        is_world_mode = render_mode in ("world", "world_true", "world_reward")
+        is_aperture_mode = render_mode in (
+            "aperture",
+            "aperture_true",
+            "aperture_reward",
+        )
         is_true_mode = render_mode in ("world_true", "aperture_true")
+        is_reward_mode = render_mode in ("world_reward", "aperture_reward")
 
         if is_world_mode:
             # Create an RGB image from the object grid
```
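The two new helpers give every cell a deterministic reward readout: `_compute_reward_grid` switches on the object id per cell under a fixed key, and `_reward_to_color` maps that scalar onto a diverging magenta-white-green ramp. A standalone sketch of the same ramp, to check the endpoints (the function name and test values here are mine, not the package's):

```python
import jax.numpy as jnp

def reward_to_color(reward):
    """Diverging colormap: -1 -> magenta, 0 -> white, +1 -> green."""
    r = jnp.clip(reward, -1.0, 1.0)
    red = jnp.where(r >= 0, (1 - r) * 255, 255)
    green = jnp.where(r >= 0, 255, (1 + r) * 255)
    blue = jnp.where(r >= 0, (1 - r) * 255, 255)
    return jnp.stack([red, green, blue], axis=-1).astype(jnp.uint8)

print(reward_to_color(jnp.array([-1.0, 0.0, 1.0])))
# [[255   0 255]   magenta
#  [255 255 255]   white
#  [  0 255   0]]  green
```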
```diff
@@ -1265,6 +1349,29 @@ class ForagaxEnv(environment.Environment):
 
             img = jax.lax.fori_loop(0, len(self.object_ids), update_image, img)
 
+            if is_reward_mode:
+                # Scale image by 3 to create space for reward visualization
+                img = jax.image.resize(
+                    img,
+                    (self.size[1] * 3, self.size[0] * 3, 3),
+                    jax.image.ResizeMethod.NEAREST,
+                )
+
+                # Compute rewards for all cells
+                reward_grid = self._compute_reward_grid(state)
+
+                # Convert rewards to colors
+                reward_colors = self._reward_to_color(reward_grid)
+
+                # Resize reward colors to match 3x scale and place in middle cells
+                # We need to place reward colors at positions (i*3+1, j*3+1) for each (i,j)
+                # Create index arrays for middle cells
+                i_indices = jnp.arange(self.size[1])[:, None] * 3 + 1
+                j_indices = jnp.arange(self.size[0])[None, :] * 3 + 1
+
+                # Broadcast and set middle cells
+                img = img.at[i_indices, j_indices].set(reward_colors)
+
             # Tint the agent's aperture
             y_coords, x_coords, y_coords_adj, x_coords_adj = (
                 self._compute_aperture_coordinates(state.pos)
```
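The middle-cell placement works by broadcast indexing: `i_indices` has shape (H, 1) and `j_indices` has shape (1, W), so `img.at[i_indices, j_indices]` addresses the full (H, W) grid of block centers in a single scatter. A toy sketch with a hypothetical 2x2 grid and one channel:

```python
import jax.numpy as jnp

H = W = 2
img = jnp.zeros((H * 3, W * 3), dtype=jnp.int32)  # 3x-upscaled grid
centers = jnp.array([[1, 2], [3, 4]])             # per-cell values to embed

# (H, 1) rows broadcast against (1, W) cols -> an (H, W) set of targets.
i_indices = jnp.arange(H)[:, None] * 3 + 1
j_indices = jnp.arange(W)[None, :] * 3 + 1
img = img.at[i_indices, j_indices].set(centers)

print(img[1, 1], img[1, 4], img[4, 1], img[4, 4])  # 1 2 3 4
```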
```diff
@@ -1273,27 +1380,127 @@ class ForagaxEnv(environment.Environment):
             alpha = 0.2
             agent_color = jnp.array(AGENT.color)
 
-            if
-            #
-
-
-
-
-
-
+            if is_reward_mode:
+                # For reward mode, we need to adjust coordinates for 3x scaled image
+                if self.nowrap:
+                    # Create tint mask for 3x scaled image
+                    tint_mask = jnp.zeros(
+                        (self.size[1] * 3, self.size[0] * 3), dtype=bool
+                    )
+
+                    # For each aperture cell, tint all 9 cells in its 3x3 block
+                    # Create meshgrid to get all aperture cell coordinates
+                    y_grid, x_grid = jnp.meshgrid(
+                        y_coords_adj.flatten(), x_coords_adj.flatten(), indexing="ij"
+                    )
+                    y_flat = y_grid.flatten()
+                    x_flat = x_grid.flatten()
+
+                    # Create offset arrays for 3x3 blocks
+                    offsets = jnp.array(
+                        [[dy, dx] for dy in range(3) for dx in range(3)]
+                    )
+
+                    # For each aperture cell, expand to 9 cells
+                    # We need to repeat each cell coordinate 9 times, then add offsets
+                    num_aperture_cells = y_flat.size
+                    y_base = jnp.repeat(
+                        y_flat * 3, 9
+                    )  # Repeat each y coord 9 times and scale by 3
+                    x_base = jnp.repeat(
+                        x_flat * 3, 9
+                    )  # Repeat each x coord 9 times and scale by 3
+                    y_offsets = jnp.tile(
+                        offsets[:, 0], num_aperture_cells
+                    )  # Tile all 9 offsets
+                    x_offsets = jnp.tile(
+                        offsets[:, 1], num_aperture_cells
+                    )  # Tile all 9 offsets
+                    y_expanded = y_base + y_offsets
+                    x_expanded = x_base + x_offsets
+
+                    tint_mask = tint_mask.at[y_expanded, x_expanded].set(True)
+
+                    original_colors = img
+                    tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
+                    img = jnp.where(tint_mask[..., None], tinted_colors, img)
+                else:
+                    # Tint all 9 cells in each 3x3 block for aperture cells
+                    # Create meshgrid to get all aperture cell coordinates
+                    y_grid, x_grid = jnp.meshgrid(
+                        y_coords_adj.flatten(), x_coords_adj.flatten(), indexing="ij"
+                    )
+                    y_flat = y_grid.flatten()
+                    x_flat = x_grid.flatten()
+
+                    # Create offset arrays for 3x3 blocks
+                    offsets = jnp.array(
+                        [[dy, dx] for dy in range(3) for dx in range(3)]
+                    )
+
+                    # For each aperture cell, expand to 9 cells
+                    # We need to repeat each cell coordinate 9 times, then add offsets
+                    num_aperture_cells = y_flat.size
+                    y_base = jnp.repeat(
+                        y_flat * 3, 9
+                    )  # Repeat each y coord 9 times and scale by 3
+                    x_base = jnp.repeat(
+                        x_flat * 3, 9
+                    )  # Repeat each x coord 9 times and scale by 3
+                    y_offsets = jnp.tile(
+                        offsets[:, 0], num_aperture_cells
+                    )  # Tile all 9 offsets
+                    x_offsets = jnp.tile(
+                        offsets[:, 1], num_aperture_cells
+                    )  # Tile all 9 offsets
+                    y_expanded = y_base + y_offsets
+                    x_expanded = x_base + x_offsets
+
+                    # Get original colors and tint them
+                    original_colors = img[y_expanded, x_expanded]
+                    tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
+                    img = img.at[y_expanded, x_expanded].set(tinted_colors)
+
+                # Agent color - set all 9 cells of the agent's 3x3 block
+                agent_y, agent_x = state.pos[1], state.pos[0]
+                agent_offsets = jnp.array(
+                    [[dy, dx] for dy in range(3) for dx in range(3)]
+                )
+                agent_y_cells = agent_y * 3 + agent_offsets[:, 0]
+                agent_x_cells = agent_x * 3 + agent_offsets[:, 1]
+                img = img.at[agent_y_cells, agent_x_cells].set(
+                    jnp.array(AGENT.color, dtype=jnp.uint8)
+                )
+
+                # Scale by 8 to final size
+                img = jax.image.resize(
+                    img,
+                    (self.size[1] * 24, self.size[0] * 24, 3),
+                    jax.image.ResizeMethod.NEAREST,
+                )
             else:
-
-
-
+                # Standard rendering without reward visualization
+                if self.nowrap:
+                    # Create tint mask: any in-bounds original position maps to a cell makes it tinted
+                    tint_mask = jnp.zeros((self.size[1], self.size[0]), dtype=int)
+                    tint_mask = tint_mask.at[y_coords_adj, x_coords_adj].set(1)
+                    # Apply tint to masked positions
+                    original_colors = img
+                    tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
+                    img = jnp.where(tint_mask[..., None], tinted_colors, img)
+                else:
+                    original_colors = img[y_coords_adj, x_coords_adj]
+                    tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
+                    img = img.at[y_coords_adj, x_coords_adj].set(tinted_colors)
 
-
-
+                # Agent color
+                img = img.at[state.pos[1], state.pos[0]].set(jnp.array(AGENT.color))
 
-
-
-
-
-
+                img = jax.image.resize(
+                    img,
+                    (self.size[1] * 24, self.size[0] * 24, 3),
+                    jax.image.ResizeMethod.NEAREST,
+                )
 
             if is_true_mode:
                 # Apply true object borders by overlaying true colors on border pixels
```
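The repeat/tile passage above is how each aperture cell fans out to the nine pixels of its 3x3 block: `jnp.repeat` writes each scaled base coordinate nine times, `jnp.tile` cycles the nine block offsets, and the sum pairs every base with every offset. A toy sketch over two cells (coordinates chosen arbitrarily):

```python
import jax.numpy as jnp

cells_y = jnp.array([0, 2])  # two hypothetical aperture cells (y coordinates)
offsets = jnp.array([[dy, dx] for dy in range(3) for dx in range(3)])

y_base = jnp.repeat(cells_y * 3, 9)   # [0]*9 + [6]*9
y_off = jnp.tile(offsets[:, 0], 2)    # [0 0 0 1 1 1 2 2 2] twice
y_expanded = y_base + y_off           # rows 0-2 for cell 0, rows 6-8 for cell 2

print(jnp.unique(y_expanded))         # [0 1 2 6 7 8]
```

The same expansion appears verbatim in both the `nowrap` and wrapping branches; only the final write differs (a boolean mask plus `jnp.where` versus a direct `.at[...].set`).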
```diff
@@ -1340,16 +1547,69 @@ class ForagaxEnv(environment.Environment):
             aperture_one_hot = jax.nn.one_hot(aperture, len(self.object_ids))
             img = jnp.tensordot(aperture_one_hot, self.object_colors, axes=1)
 
-
-
-
+            if is_reward_mode:
+                # Scale image by 3 to create space for reward visualization
+                img = img.astype(jnp.uint8)
+                img = jax.image.resize(
+                    img,
+                    (self.aperture_size[0] * 3, self.aperture_size[1] * 3, 3),
+                    jax.image.ResizeMethod.NEAREST,
+                )
 
-
-
-
-
-
-
+                # Compute rewards for aperture region
+                y_coords, x_coords, y_coords_adj, x_coords_adj = (
+                    self._compute_aperture_coordinates(state.pos)
+                )
+
+                # Get reward grid for the full world
+                full_reward_grid = self._compute_reward_grid(state)
+
+                # Extract aperture rewards
+                aperture_rewards = full_reward_grid[y_coords_adj, x_coords_adj]
+
+                # Convert rewards to colors
+                reward_colors = self._reward_to_color(aperture_rewards)
+
+                # Place reward colors in the middle cells (index 1 in each 3x3 block)
+                i_indices = jnp.arange(self.aperture_size[0])[:, None] * 3 + 1
+                j_indices = jnp.arange(self.aperture_size[1])[None, :] * 3 + 1
+                img = img.at[i_indices, j_indices].set(reward_colors)
+
+                # Draw agent in the center (all 9 cells of the 3x3 block)
+                center_y, center_x = (
+                    self.aperture_size[1] // 2,
+                    self.aperture_size[0] // 2,
+                )
+                agent_offsets = jnp.array(
+                    [[dy, dx] for dy in range(3) for dx in range(3)]
+                )
+                agent_y_cells = center_y * 3 + agent_offsets[:, 0]
+                agent_x_cells = center_x * 3 + agent_offsets[:, 1]
+                img = img.at[agent_y_cells, agent_x_cells].set(
+                    jnp.array(AGENT.color, dtype=jnp.uint8)
+                )
+
+                # Scale by 8 to final size
+                img = jax.image.resize(
+                    img,
+                    (self.aperture_size[0] * 24, self.aperture_size[1] * 24, 3),
+                    jax.image.ResizeMethod.NEAREST,
+                )
+            else:
+                # Standard rendering without reward visualization
+                # Draw agent in the center
+                center_y, center_x = (
+                    self.aperture_size[1] // 2,
+                    self.aperture_size[0] // 2,
+                )
+                img = img.at[center_y, center_x].set(jnp.array(AGENT.color))
+
+                img = img.astype(jnp.uint8)
+                img = jax.image.resize(
+                    img,
+                    (self.aperture_size[0] * 24, self.aperture_size[1] * 24, 3),
+                    jax.image.ResizeMethod.NEAREST,
+                )
 
             if is_true_mode:
                 # Apply true object borders by overlaying true colors on border pixels
```
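Taken together, the new `world_reward` and `aperture_reward` modes draw each cell as a 3x3 block whose center pixel encodes the cell's current reward. A hedged usage sketch: `make` and the render signature appear in this diff, but the gymnax-style `default_params`/`reset` threading is my assumption about the surrounding API:

```python
import jax
from foragax.registry import make

env = make("ForagaxDiwali-v2")  # registered below in registry.py
params = env.default_params
obs, state = env.reset(jax.random.key(0), params)

# Same grid, two views: plain object colors vs. the reward overlay.
frame = env.render(state, params, render_mode="world")
frame_reward = env.render(state, params, render_mode="world_reward")
print(frame.shape, frame_reward.shape)  # both (H*24, W*24, 3)
```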
foragax/objects.py
CHANGED

```diff
@@ -240,6 +240,7 @@ class FourierObject(BaseForagaxObject):
         color: Tuple[int, int, int] = (0, 0, 0),
         reward_delay: int = 0,
         max_reward_delay: Optional[int] = None,
+        regen_delay: Optional[Tuple[int, int]] = None,
     ):
         if max_reward_delay is None:
             max_reward_delay = reward_delay
@@ -248,13 +249,14 @@ class FourierObject(BaseForagaxObject):
             blocking=False,
             collectable=True,
             color=color,
-            random_respawn=
+            random_respawn=True,
             max_reward_delay=max_reward_delay,
             expiry_time=None,
         )
         self.num_fourier_terms = num_fourier_terms
         self.base_magnitude = base_magnitude
         self.reward_delay_val = reward_delay
+        self.regen_delay_range = regen_delay
 
     def get_state(self, key: jax.Array) -> jax.Array:
         """Generate random Fourier series parameters.
@@ -353,6 +355,9 @@ class FourierObject(BaseForagaxObject):
 
     def regen_delay(self, clock: int, rng: jax.Array) -> int:
         """No individual regeneration - returns infinity."""
+        if self.regen_delay_range is not None:
+            min_delay, max_delay = self.regen_delay_range
+            return jax.random.randint(rng, (), min_delay, max_delay)
         return jnp.iinfo(jnp.int32).max
 
     def expiry_regen_delay(self, clock: int, rng: jax.Array) -> int:
```
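One detail worth noting: `jax.random.randint` samples with an exclusive upper bound, so the `(9, 11)` range that ForagaxDiwali-v2 passes below yields delays of 9 or 10 steps, never 11. A quick sketch of that behaviour:

```python
import jax

min_delay, max_delay = 9, 11  # the range used by ForagaxDiwali-v2
keys = jax.random.split(jax.random.key(0), 1000)
delays = jax.vmap(lambda k: jax.random.randint(k, (), min_delay, max_delay))(keys)
print(delays.min(), delays.max())  # 9 10 (max_delay itself is never drawn)
```

The docstring still reads "No individual regeneration - returns infinity", which now only holds when `regen_delay` was not supplied.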
```diff
@@ -709,6 +714,7 @@ def create_fourier_objects(
     num_fourier_terms: int = 10,
     base_magnitude: float = 1.0,
     reward_delay: int = 0,
+    regen_delay: Optional[Tuple[int, int]] = None,
 ):
     """Create HOT and COLD FourierObject instances.
 
@@ -726,6 +732,7 @@ def create_fourier_objects(
         base_magnitude=base_magnitude,
         color=(0, 0, 0),
         reward_delay=reward_delay,
+        regen_delay=regen_delay,
    )
 
     cold = FourierObject(
@@ -734,6 +741,7 @@ def create_fourier_objects(
         base_magnitude=base_magnitude,
         color=(0, 0, 0),
         reward_delay=reward_delay,
+        regen_delay=regen_delay,
     )
 
     return hot, cold
```
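With that plumbing, timed regrowth is opt-in per factory call; omitting `regen_delay` keeps the old behaviour of effectively never regenerating. A hedged sketch of the factory as the diff presents it (the import path is inferred from the file name):

```python
from foragax.objects import create_fourier_objects

# Old behaviour: collected objects never individually regenerate.
hot, cold = create_fourier_objects(num_fourier_terms=10)

# New in 0.35.0: both objects respawn 9-10 steps after being collected.
hot, cold = create_fourier_objects(num_fourier_terms=10, regen_delay=(9, 11))
```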
foragax/registry.py
CHANGED

```diff
@@ -104,6 +104,21 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
         "dynamic_biomes": True,
         "biome_consumption_threshold": 0.9,
     },
+    "ForagaxDiwali-v2": {
+        "size": (15, 15),
+        "aperture_size": None,
+        "objects": None,
+        "biomes": (
+            # Hot biome
+            Biome(start=(0, 2), stop=(15, 6), object_frequencies=(0.5, 0.0)),
+            # Cold biome
+            Biome(start=(0, 9), stop=(15, 13), object_frequencies=(0.0, 0.5)),
+        ),
+        "nowrap": False,
+        "deterministic_spawn": True,
+        "dynamic_biomes": True,
+        "biome_consumption_threshold": 200,
+    },
     "ForagaxTwoBiome-v1": {
         "size": (15, 15),
         "aperture_size": None,
@@ -539,6 +554,12 @@ def make(
             num_fourier_terms=10,
             reward_delay=reward_delay,
         )
+    if env_id == "ForagaxDiwali-v2":
+        config["objects"] = create_fourier_objects(
+            num_fourier_terms=10,
+            reward_delay=reward_delay,
+            regen_delay=(9, 11),
+        )
 
     if env_id == "ForagaxSineTwoBiome-v1":
         biome1_oyster, biome1_deathcap, biome2_oyster, biome2_deathcap = (
```
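ForagaxDiwali-v2 pairs an integer `biome_consumption_threshold` (200) with the absolute-count branch added to env.py above, while the config just above it keeps the fractional 0.9. A sketch of inspecting that, assuming `ENV_CONFIGS` is importable as shown in this diff:

```python
from foragax.registry import ENV_CONFIGS

threshold = ENV_CONFIGS["ForagaxDiwali-v2"]["biome_consumption_threshold"]

# float -> respawn when a fraction of the biome is consumed;
# int   -> respawn after an absolute number of consumptions.
kind = "fraction" if isinstance(threshold, float) else "absolute count"
print(threshold, kind)  # 200 absolute count
```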
{continual_foragax-0.33.2.dist-info → continual_foragax-0.35.0.dist-info}/WHEEL
File without changes

{continual_foragax-0.33.2.dist-info → continual_foragax-0.35.0.dist-info}/entry_points.txt
File without changes

{continual_foragax-0.33.2.dist-info → continual_foragax-0.35.0.dist-info}/top_level.txt
File without changes