continual-foragax 0.33.2__py3-none-any.whl → 0.34.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.33.2.dist-info → continual_foragax-0.34.0.dist-info}/METADATA +1 -1
- {continual_foragax-0.33.2.dist-info → continual_foragax-0.34.0.dist-info}/RECORD +6 -6
- foragax/env.py +286 -31
- {continual_foragax-0.33.2.dist-info → continual_foragax-0.34.0.dist-info}/WHEEL +0 -0
- {continual_foragax-0.33.2.dist-info → continual_foragax-0.34.0.dist-info}/entry_points.txt +0 -0
- {continual_foragax-0.33.2.dist-info → continual_foragax-0.34.0.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
2
|
foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
|
|
3
|
-
foragax/env.py,sha256=
|
|
3
|
+
foragax/env.py,sha256=l80niatCet3Kdveev8rIScTQwkmmnxIh4AhZgL7CJOA,66099
|
|
4
4
|
foragax/objects.py,sha256=9wv0ZKT89dDkaeVwUwkVo4dwhRVeUxvsTyhoyYKfOEw,26508
|
|
5
5
|
foragax/registry.py,sha256=G_xpDsSJIclEjqxU_xtkOhv4KvPLp5y8Cq2x7VsasiQ,18092
|
|
6
6
|
foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
|
|
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
|
|
|
128
128
|
foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
|
|
129
129
|
foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
|
|
130
130
|
foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
|
|
131
|
-
continual_foragax-0.
|
|
132
|
-
continual_foragax-0.
|
|
133
|
-
continual_foragax-0.
|
|
134
|
-
continual_foragax-0.
|
|
135
|
-
continual_foragax-0.
|
|
131
|
+
continual_foragax-0.34.0.dist-info/METADATA,sha256=E0gsxBGuPG2UZhCmNyWNlg6SGAkPi1tjtnqgwC69t1g,4713
|
|
132
|
+
continual_foragax-0.34.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
133
|
+
continual_foragax-0.34.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
|
|
134
|
+
continual_foragax-0.34.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
|
|
135
|
+
continual_foragax-0.34.0.dist-info/RECORD,,
|
foragax/env.py
CHANGED
|
@@ -1235,12 +1235,91 @@ class ForagaxEnv(environment.Environment):
|
|
|
1235
1235
|
|
|
1236
1236
|
return spaces.Box(0, 1, obs_shape, float)
|
|
1237
1237
|
|
|
1238
|
+
def _compute_reward_grid(self, state: EnvState) -> jax.Array:
    """Evaluate the reward function of every grid cell for the current state.

    Each cell's ``object_id`` selects an entry of ``self.reward_fns`` via
    ``jax.lax.switch``. That function is called with ``state.time``, a constant
    PRNG key and the cell's ``state_params``. Cells whose id is not positive
    (``object_id <= 0``) contribute ``0.0`` instead.

    Args:
        state: Environment state supplying ``time`` and ``object_state``
            (its ``object_id`` and ``state_params`` are mapped cell by cell).

    Returns:
        Array with one reward per cell, laid out like
        ``state.object_state.object_id`` (expected to be (H, W)).
    """
    # Constant key so the reward map is deterministic from call to call.
    deterministic_key = jax.random.key(0)
    reward_fns = self.reward_fns
    current_time = state.time

    def _cell_reward(cell_id, cell_params):
        def _lookup():
            return jax.lax.switch(
                cell_id, reward_fns, current_time, deterministic_key, cell_params
            )

        def _no_object():
            return 0.0

        return jax.lax.cond(cell_id > 0, _lookup, _no_object)

    # Map over the two leading axes (rows, then columns) of the object grid.
    per_row = jax.vmap(_cell_reward)
    per_grid = jax.vmap(per_row)
    object_grid = state.object_state
    return per_grid(object_grid.object_id, object_grid.state_params)
|
|
1259
|
+
|
|
1260
|
+
def _reward_to_color(self, reward: jax.Array) -> jax.Array:
    """Map reward values onto a green / white / magenta diverging color scale.

    Rewards are first clipped to ``[-1, 1]``. Zero maps to white
    ``(255, 255, 255)``. Non-negative rewards fade linearly toward green
    ``(0, 255, 0)``, which is reached at ``+1``. Negative rewards fade linearly
    toward magenta ``(255, 0, 255)``, which is reached at ``-1``. The float
    channel values are cast straight to ``uint8`` with no explicit rounding.

    Args:
        reward: Reward value(s) of any shape (typically in -1 to +1).

    Returns:
        ``uint8`` array of shape ``reward.shape + (3,)`` holding RGB colors.
    """
    r = jnp.clip(reward, -1.0, 1.0)
    non_negative = r >= 0

    # Red and blue follow the same profile: they fall to 0 as the reward
    # rises to +1, and stay saturated for any negative reward.
    red_and_blue = jnp.where(non_negative, (1 - r) * 255, 255)
    # Green stays saturated for non-negative rewards and falls to 0 at -1.
    green = jnp.where(non_negative, 255, (1 + r) * 255)

    rgb = jnp.stack([red_and_blue, green, red_and_blue], axis=-1)
    return rgb.astype(jnp.uint8)
|
|
1300
|
+
|
|
1238
1301
|
@partial(jax.jit, static_argnames=("self", "render_mode"))
|
|
1239
|
-
def render(
|
|
1240
|
-
|
|
1241
|
-
|
|
1242
|
-
|
|
1302
|
+
def render(
|
|
1303
|
+
self,
|
|
1304
|
+
state: EnvState,
|
|
1305
|
+
params: EnvParams,
|
|
1306
|
+
render_mode: str = "world",
|
|
1307
|
+
):
|
|
1308
|
+
"""Render the environment state.
|
|
1309
|
+
|
|
1310
|
+
Args:
|
|
1311
|
+
state: Current environment state
|
|
1312
|
+
params: Environment parameters
|
|
1313
|
+
render_mode: One of "world", "world_true", "world_reward", "aperture", "aperture_true", "aperture_reward"
|
|
1314
|
+
"""
|
|
1315
|
+
is_world_mode = render_mode in ("world", "world_true", "world_reward")
|
|
1316
|
+
is_aperture_mode = render_mode in (
|
|
1317
|
+
"aperture",
|
|
1318
|
+
"aperture_true",
|
|
1319
|
+
"aperture_reward",
|
|
1320
|
+
)
|
|
1243
1321
|
is_true_mode = render_mode in ("world_true", "aperture_true")
|
|
1322
|
+
is_reward_mode = render_mode in ("world_reward", "aperture_reward")
|
|
1244
1323
|
|
|
1245
1324
|
if is_world_mode:
|
|
1246
1325
|
# Create an RGB image from the object grid
|
|
@@ -1265,6 +1344,29 @@ class ForagaxEnv(environment.Environment):
|
|
|
1265
1344
|
|
|
1266
1345
|
img = jax.lax.fori_loop(0, len(self.object_ids), update_image, img)
|
|
1267
1346
|
|
|
1347
|
+
if is_reward_mode:
|
|
1348
|
+
# Scale image by 3 to create space for reward visualization
|
|
1349
|
+
img = jax.image.resize(
|
|
1350
|
+
img,
|
|
1351
|
+
(self.size[1] * 3, self.size[0] * 3, 3),
|
|
1352
|
+
jax.image.ResizeMethod.NEAREST,
|
|
1353
|
+
)
|
|
1354
|
+
|
|
1355
|
+
# Compute rewards for all cells
|
|
1356
|
+
reward_grid = self._compute_reward_grid(state)
|
|
1357
|
+
|
|
1358
|
+
# Convert rewards to colors
|
|
1359
|
+
reward_colors = self._reward_to_color(reward_grid)
|
|
1360
|
+
|
|
1361
|
+
# Resize reward colors to match 3x scale and place in middle cells
|
|
1362
|
+
# We need to place reward colors at positions (i*3+1, j*3+1) for each (i,j)
|
|
1363
|
+
# Create index arrays for middle cells
|
|
1364
|
+
i_indices = jnp.arange(self.size[1])[:, None] * 3 + 1
|
|
1365
|
+
j_indices = jnp.arange(self.size[0])[None, :] * 3 + 1
|
|
1366
|
+
|
|
1367
|
+
# Broadcast and set middle cells
|
|
1368
|
+
img = img.at[i_indices, j_indices].set(reward_colors)
|
|
1369
|
+
|
|
1268
1370
|
# Tint the agent's aperture
|
|
1269
1371
|
y_coords, x_coords, y_coords_adj, x_coords_adj = (
|
|
1270
1372
|
self._compute_aperture_coordinates(state.pos)
|
|
@@ -1273,27 +1375,127 @@ class ForagaxEnv(environment.Environment):
|
|
|
1273
1375
|
alpha = 0.2
|
|
1274
1376
|
agent_color = jnp.array(AGENT.color)
|
|
1275
1377
|
|
|
1276
|
-
if
|
|
1277
|
-
#
|
|
1278
|
-
|
|
1279
|
-
|
|
1280
|
-
|
|
1281
|
-
|
|
1282
|
-
|
|
1283
|
-
|
|
1378
|
+
if is_reward_mode:
|
|
1379
|
+
# For reward mode, we need to adjust coordinates for 3x scaled image
|
|
1380
|
+
if self.nowrap:
|
|
1381
|
+
# Create tint mask for 3x scaled image
|
|
1382
|
+
tint_mask = jnp.zeros(
|
|
1383
|
+
(self.size[1] * 3, self.size[0] * 3), dtype=bool
|
|
1384
|
+
)
|
|
1385
|
+
|
|
1386
|
+
# For each aperture cell, tint all 9 cells in its 3x3 block
|
|
1387
|
+
# Create meshgrid to get all aperture cell coordinates
|
|
1388
|
+
y_grid, x_grid = jnp.meshgrid(
|
|
1389
|
+
y_coords_adj.flatten(), x_coords_adj.flatten(), indexing="ij"
|
|
1390
|
+
)
|
|
1391
|
+
y_flat = y_grid.flatten()
|
|
1392
|
+
x_flat = x_grid.flatten()
|
|
1393
|
+
|
|
1394
|
+
# Create offset arrays for 3x3 blocks
|
|
1395
|
+
offsets = jnp.array(
|
|
1396
|
+
[[dy, dx] for dy in range(3) for dx in range(3)]
|
|
1397
|
+
)
|
|
1398
|
+
|
|
1399
|
+
# For each aperture cell, expand to 9 cells
|
|
1400
|
+
# We need to repeat each cell coordinate 9 times, then add offsets
|
|
1401
|
+
num_aperture_cells = y_flat.size
|
|
1402
|
+
y_base = jnp.repeat(
|
|
1403
|
+
y_flat * 3, 9
|
|
1404
|
+
) # Repeat each y coord 9 times and scale by 3
|
|
1405
|
+
x_base = jnp.repeat(
|
|
1406
|
+
x_flat * 3, 9
|
|
1407
|
+
) # Repeat each x coord 9 times and scale by 3
|
|
1408
|
+
y_offsets = jnp.tile(
|
|
1409
|
+
offsets[:, 0], num_aperture_cells
|
|
1410
|
+
) # Tile all 9 offsets
|
|
1411
|
+
x_offsets = jnp.tile(
|
|
1412
|
+
offsets[:, 1], num_aperture_cells
|
|
1413
|
+
) # Tile all 9 offsets
|
|
1414
|
+
y_expanded = y_base + y_offsets
|
|
1415
|
+
x_expanded = x_base + x_offsets
|
|
1416
|
+
|
|
1417
|
+
tint_mask = tint_mask.at[y_expanded, x_expanded].set(True)
|
|
1418
|
+
|
|
1419
|
+
original_colors = img
|
|
1420
|
+
tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
|
|
1421
|
+
img = jnp.where(tint_mask[..., None], tinted_colors, img)
|
|
1422
|
+
else:
|
|
1423
|
+
# Tint all 9 cells in each 3x3 block for aperture cells
|
|
1424
|
+
# Create meshgrid to get all aperture cell coordinates
|
|
1425
|
+
y_grid, x_grid = jnp.meshgrid(
|
|
1426
|
+
y_coords_adj.flatten(), x_coords_adj.flatten(), indexing="ij"
|
|
1427
|
+
)
|
|
1428
|
+
y_flat = y_grid.flatten()
|
|
1429
|
+
x_flat = x_grid.flatten()
|
|
1430
|
+
|
|
1431
|
+
# Create offset arrays for 3x3 blocks
|
|
1432
|
+
offsets = jnp.array(
|
|
1433
|
+
[[dy, dx] for dy in range(3) for dx in range(3)]
|
|
1434
|
+
)
|
|
1435
|
+
|
|
1436
|
+
# For each aperture cell, expand to 9 cells
|
|
1437
|
+
# We need to repeat each cell coordinate 9 times, then add offsets
|
|
1438
|
+
num_aperture_cells = y_flat.size
|
|
1439
|
+
y_base = jnp.repeat(
|
|
1440
|
+
y_flat * 3, 9
|
|
1441
|
+
) # Repeat each y coord 9 times and scale by 3
|
|
1442
|
+
x_base = jnp.repeat(
|
|
1443
|
+
x_flat * 3, 9
|
|
1444
|
+
) # Repeat each x coord 9 times and scale by 3
|
|
1445
|
+
y_offsets = jnp.tile(
|
|
1446
|
+
offsets[:, 0], num_aperture_cells
|
|
1447
|
+
) # Tile all 9 offsets
|
|
1448
|
+
x_offsets = jnp.tile(
|
|
1449
|
+
offsets[:, 1], num_aperture_cells
|
|
1450
|
+
) # Tile all 9 offsets
|
|
1451
|
+
y_expanded = y_base + y_offsets
|
|
1452
|
+
x_expanded = x_base + x_offsets
|
|
1453
|
+
|
|
1454
|
+
# Get original colors and tint them
|
|
1455
|
+
original_colors = img[y_expanded, x_expanded]
|
|
1456
|
+
tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
|
|
1457
|
+
img = img.at[y_expanded, x_expanded].set(tinted_colors)
|
|
1458
|
+
|
|
1459
|
+
# Agent color - set all 9 cells of the agent's 3x3 block
|
|
1460
|
+
agent_y, agent_x = state.pos[1], state.pos[0]
|
|
1461
|
+
agent_offsets = jnp.array(
|
|
1462
|
+
[[dy, dx] for dy in range(3) for dx in range(3)]
|
|
1463
|
+
)
|
|
1464
|
+
agent_y_cells = agent_y * 3 + agent_offsets[:, 0]
|
|
1465
|
+
agent_x_cells = agent_x * 3 + agent_offsets[:, 1]
|
|
1466
|
+
img = img.at[agent_y_cells, agent_x_cells].set(
|
|
1467
|
+
jnp.array(AGENT.color, dtype=jnp.uint8)
|
|
1468
|
+
)
|
|
1469
|
+
|
|
1470
|
+
# Scale by 8 to final size
|
|
1471
|
+
img = jax.image.resize(
|
|
1472
|
+
img,
|
|
1473
|
+
(self.size[1] * 24, self.size[0] * 24, 3),
|
|
1474
|
+
jax.image.ResizeMethod.NEAREST,
|
|
1475
|
+
)
|
|
1284
1476
|
else:
|
|
1285
|
-
|
|
1286
|
-
|
|
1287
|
-
|
|
1477
|
+
# Standard rendering without reward visualization
|
|
1478
|
+
if self.nowrap:
|
|
1479
|
+
# Create tint mask: any in-bounds original position maps to a cell makes it tinted
|
|
1480
|
+
tint_mask = jnp.zeros((self.size[1], self.size[0]), dtype=int)
|
|
1481
|
+
tint_mask = tint_mask.at[y_coords_adj, x_coords_adj].set(1)
|
|
1482
|
+
# Apply tint to masked positions
|
|
1483
|
+
original_colors = img
|
|
1484
|
+
tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
|
|
1485
|
+
img = jnp.where(tint_mask[..., None], tinted_colors, img)
|
|
1486
|
+
else:
|
|
1487
|
+
original_colors = img[y_coords_adj, x_coords_adj]
|
|
1488
|
+
tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
|
|
1489
|
+
img = img.at[y_coords_adj, x_coords_adj].set(tinted_colors)
|
|
1288
1490
|
|
|
1289
|
-
|
|
1290
|
-
|
|
1491
|
+
# Agent color
|
|
1492
|
+
img = img.at[state.pos[1], state.pos[0]].set(jnp.array(AGENT.color))
|
|
1291
1493
|
|
|
1292
|
-
|
|
1293
|
-
|
|
1294
|
-
|
|
1295
|
-
|
|
1296
|
-
|
|
1494
|
+
img = jax.image.resize(
|
|
1495
|
+
img,
|
|
1496
|
+
(self.size[1] * 24, self.size[0] * 24, 3),
|
|
1497
|
+
jax.image.ResizeMethod.NEAREST,
|
|
1498
|
+
)
|
|
1297
1499
|
|
|
1298
1500
|
if is_true_mode:
|
|
1299
1501
|
# Apply true object borders by overlaying true colors on border pixels
|
|
@@ -1340,16 +1542,69 @@ class ForagaxEnv(environment.Environment):
|
|
|
1340
1542
|
aperture_one_hot = jax.nn.one_hot(aperture, len(self.object_ids))
|
|
1341
1543
|
img = jnp.tensordot(aperture_one_hot, self.object_colors, axes=1)
|
|
1342
1544
|
|
|
1343
|
-
|
|
1344
|
-
|
|
1345
|
-
|
|
1545
|
+
if is_reward_mode:
|
|
1546
|
+
# Scale image by 3 to create space for reward visualization
|
|
1547
|
+
img = img.astype(jnp.uint8)
|
|
1548
|
+
img = jax.image.resize(
|
|
1549
|
+
img,
|
|
1550
|
+
(self.aperture_size[0] * 3, self.aperture_size[1] * 3, 3),
|
|
1551
|
+
jax.image.ResizeMethod.NEAREST,
|
|
1552
|
+
)
|
|
1346
1553
|
|
|
1347
|
-
|
|
1348
|
-
|
|
1349
|
-
|
|
1350
|
-
|
|
1351
|
-
|
|
1352
|
-
|
|
1554
|
+
# Compute rewards for aperture region
|
|
1555
|
+
y_coords, x_coords, y_coords_adj, x_coords_adj = (
|
|
1556
|
+
self._compute_aperture_coordinates(state.pos)
|
|
1557
|
+
)
|
|
1558
|
+
|
|
1559
|
+
# Get reward grid for the full world
|
|
1560
|
+
full_reward_grid = self._compute_reward_grid(state)
|
|
1561
|
+
|
|
1562
|
+
# Extract aperture rewards
|
|
1563
|
+
aperture_rewards = full_reward_grid[y_coords_adj, x_coords_adj]
|
|
1564
|
+
|
|
1565
|
+
# Convert rewards to colors
|
|
1566
|
+
reward_colors = self._reward_to_color(aperture_rewards)
|
|
1567
|
+
|
|
1568
|
+
# Place reward colors in the middle cells (index 1 in each 3x3 block)
|
|
1569
|
+
i_indices = jnp.arange(self.aperture_size[0])[:, None] * 3 + 1
|
|
1570
|
+
j_indices = jnp.arange(self.aperture_size[1])[None, :] * 3 + 1
|
|
1571
|
+
img = img.at[i_indices, j_indices].set(reward_colors)
|
|
1572
|
+
|
|
1573
|
+
# Draw agent in the center (all 9 cells of the 3x3 block)
|
|
1574
|
+
center_y, center_x = (
|
|
1575
|
+
self.aperture_size[1] // 2,
|
|
1576
|
+
self.aperture_size[0] // 2,
|
|
1577
|
+
)
|
|
1578
|
+
agent_offsets = jnp.array(
|
|
1579
|
+
[[dy, dx] for dy in range(3) for dx in range(3)]
|
|
1580
|
+
)
|
|
1581
|
+
agent_y_cells = center_y * 3 + agent_offsets[:, 0]
|
|
1582
|
+
agent_x_cells = center_x * 3 + agent_offsets[:, 1]
|
|
1583
|
+
img = img.at[agent_y_cells, agent_x_cells].set(
|
|
1584
|
+
jnp.array(AGENT.color, dtype=jnp.uint8)
|
|
1585
|
+
)
|
|
1586
|
+
|
|
1587
|
+
# Scale by 8 to final size
|
|
1588
|
+
img = jax.image.resize(
|
|
1589
|
+
img,
|
|
1590
|
+
(self.aperture_size[0] * 24, self.aperture_size[1] * 24, 3),
|
|
1591
|
+
jax.image.ResizeMethod.NEAREST,
|
|
1592
|
+
)
|
|
1593
|
+
else:
|
|
1594
|
+
# Standard rendering without reward visualization
|
|
1595
|
+
# Draw agent in the center
|
|
1596
|
+
center_y, center_x = (
|
|
1597
|
+
self.aperture_size[1] // 2,
|
|
1598
|
+
self.aperture_size[0] // 2,
|
|
1599
|
+
)
|
|
1600
|
+
img = img.at[center_y, center_x].set(jnp.array(AGENT.color))
|
|
1601
|
+
|
|
1602
|
+
img = img.astype(jnp.uint8)
|
|
1603
|
+
img = jax.image.resize(
|
|
1604
|
+
img,
|
|
1605
|
+
(self.aperture_size[0] * 24, self.aperture_size[1] * 24, 3),
|
|
1606
|
+
jax.image.ResizeMethod.NEAREST,
|
|
1607
|
+
)
|
|
1353
1608
|
|
|
1354
1609
|
if is_true_mode:
|
|
1355
1610
|
# Apply true object borders by overlaying true colors on border pixels
|
|
File without changes
|
|
File without changes
|
|
File without changes
|