continual-foragax 0.33.2__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: continual-foragax
3
- Version: 0.33.2
3
+ Version: 0.35.0
4
4
  Summary: A continual reinforcement learning benchmark
5
5
  Author-email: Steven Tang <stang5@ualberta.ca>
6
6
  Requires-Python: >=3.8
@@ -1,8 +1,8 @@
1
1
  foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
3
- foragax/env.py,sha256=bDhNQTaqcoOwm9Csb1LHoduuNdE1j1RAhGnVV7cAEPI,55147
4
- foragax/objects.py,sha256=9wv0ZKT89dDkaeVwUwkVo4dwhRVeUxvsTyhoyYKfOEw,26508
5
- foragax/registry.py,sha256=G_xpDsSJIclEjqxU_xtkOhv4KvPLp5y8Cq2x7VsasiQ,18092
3
+ foragax/env.py,sha256=K3noPwdYmQlnXVjslqVzX_FIB-CnOh37mWFArQXnf_Y,66324
4
+ foragax/objects.py,sha256=PPuLYjD7em7GL404eSpP6q8TxF8p7JtQ1kIwh7uD_tU,26860
5
+ foragax/registry.py,sha256=Ph_Z3O5GpIjrgvbKL-8Iq-Kc6MqfZIsF9KDDDzm7N3o,18787
6
6
  foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
7
7
  foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
8
8
  foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
128
128
  foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
129
129
  foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
130
130
  foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
131
- continual_foragax-0.33.2.dist-info/METADATA,sha256=m6Om1tR_YnBz6lgN6iL_Ok0u9ucvX8IWM2E_cPYBECk,4713
132
- continual_foragax-0.33.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
- continual_foragax-0.33.2.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
- continual_foragax-0.33.2.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
- continual_foragax-0.33.2.dist-info/RECORD,,
131
+ continual_foragax-0.35.0.dist-info/METADATA,sha256=pZ1uSNXsaYkaFaz7xvO9X7bee8Jr9RjT7nbwW3tx5ps,4713
132
+ continual_foragax-0.35.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
+ continual_foragax-0.35.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
+ continual_foragax-0.35.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
+ continual_foragax-0.35.0.dist-info/RECORD,,
foragax/env.py CHANGED
@@ -690,11 +690,16 @@ class ForagaxEnv(environment.Environment):
690
690
 
691
691
  num_biomes = self.biome_object_frequencies.shape[0]
692
692
 
693
- # Compute consumption rates for all biomes
694
- consumption_rates = biome_state.consumption_count / jnp.maximum(
695
- 1.0, biome_state.total_objects.astype(float)
696
- )
697
- should_respawn = consumption_rates >= self.biome_consumption_threshold
693
+ if isinstance(self.biome_consumption_threshold, float):
694
+ # Compute consumption rates for all biomes
695
+ consumption_rates = biome_state.consumption_count / jnp.maximum(
696
+ 1.0, biome_state.total_objects.astype(float)
697
+ )
698
+ should_respawn = consumption_rates >= self.biome_consumption_threshold
699
+ else:
700
+ should_respawn = (
701
+ biome_state.consumption_count >= self.biome_consumption_threshold
702
+ )
698
703
 
699
704
  # Split key for all biomes in parallel
700
705
  key, subkey = jax.random.split(key)
@@ -1235,12 +1240,91 @@ class ForagaxEnv(environment.Environment):
1235
1240
 
1236
1241
  return spaces.Box(0, 1, obs_shape, float)
1237
1242
 
1243
+ def _compute_reward_grid(self, state: EnvState) -> jax.Array:
1244
+ """Compute rewards for all grid positions.
1245
+
1246
+ Returns:
1247
+ Array of shape (H, W) with reward values for each cell
1248
+ """
1249
+ fixed_key = jax.random.key(0) # Fixed key for deterministic reward computation
1250
+
1251
+ def compute_reward(obj_id, params):
1252
+ return jax.lax.cond(
1253
+ obj_id > 0,
1254
+ lambda: jax.lax.switch(
1255
+ obj_id, self.reward_fns, state.time, fixed_key, params
1256
+ ),
1257
+ lambda: 0.0,
1258
+ )
1259
+
1260
+ reward_grid = jax.vmap(jax.vmap(compute_reward))(
1261
+ state.object_state.object_id, state.object_state.state_params
1262
+ )
1263
+ return reward_grid
1264
+
1265
+ def _reward_to_color(self, reward: jax.Array) -> jax.Array:
1266
+ """Convert reward value to RGB color using diverging gradient.
1267
+
1268
+ Args:
1269
+ reward: Reward value (typically -1 to +1)
1270
+
1271
+ Returns:
1272
+ RGB color array with shape (..., 3) and dtype uint8
1273
+ """
1274
+ # Diverging gradient: +1 = green (0, 255, 0), 0 = white (255, 255, 255), -1 = magenta (255, 0, 255)
1275
+ # Clamp reward to [-1, 1] range for color mapping
1276
+ reward_clamped = jnp.clip(reward, -1.0, 1.0)
1277
+
1278
+ # For positive rewards: interpolate from white to green
1279
+ # For negative rewards: interpolate from white to magenta
1280
+ # At reward = 0: white (255, 255, 255)
1281
+ # At reward = +1: green (0, 255, 0)
1282
+ # At reward = -1: magenta (255, 0, 255)
1283
+
1284
+ red_component = jnp.where(
1285
+ reward_clamped >= 0,
1286
+ (1 - reward_clamped) * 255, # Fade from white to green: 255 -> 0
1287
+ 255, # Stay at 255 for all negative rewards
1288
+ )
1289
+
1290
+ green_component = jnp.where(
1291
+ reward_clamped >= 0,
1292
+ 255, # Stay at 255 for all positive rewards
1293
+ (1 + reward_clamped) * 255, # Fade from white to magenta: 255 -> 0
1294
+ )
1295
+
1296
+ blue_component = jnp.where(
1297
+ reward_clamped >= 0,
1298
+ (1 - reward_clamped) * 255, # Fade from white to green: 255 -> 0
1299
+ 255, # Stay at 255 for all negative rewards
1300
+ )
1301
+
1302
+ return jnp.stack(
1303
+ [red_component, green_component, blue_component], axis=-1
1304
+ ).astype(jnp.uint8)
1305
+
1238
1306
  @partial(jax.jit, static_argnames=("self", "render_mode"))
1239
- def render(self, state: EnvState, params: EnvParams, render_mode: str = "world"):
1240
- """Render the environment state."""
1241
- is_world_mode = render_mode in ("world", "world_true")
1242
- is_aperture_mode = render_mode in ("aperture", "aperture_true")
1307
+ def render(
1308
+ self,
1309
+ state: EnvState,
1310
+ params: EnvParams,
1311
+ render_mode: str = "world",
1312
+ ):
1313
+ """Render the environment state.
1314
+
1315
+ Args:
1316
+ state: Current environment state
1317
+ params: Environment parameters
1318
+ render_mode: One of "world", "world_true", "world_reward", "aperture", "aperture_true", "aperture_reward"
1319
+ """
1320
+ is_world_mode = render_mode in ("world", "world_true", "world_reward")
1321
+ is_aperture_mode = render_mode in (
1322
+ "aperture",
1323
+ "aperture_true",
1324
+ "aperture_reward",
1325
+ )
1243
1326
  is_true_mode = render_mode in ("world_true", "aperture_true")
1327
+ is_reward_mode = render_mode in ("world_reward", "aperture_reward")
1244
1328
 
1245
1329
  if is_world_mode:
1246
1330
  # Create an RGB image from the object grid
@@ -1265,6 +1349,29 @@ class ForagaxEnv(environment.Environment):
1265
1349
 
1266
1350
  img = jax.lax.fori_loop(0, len(self.object_ids), update_image, img)
1267
1351
 
1352
+ if is_reward_mode:
1353
+ # Scale image by 3 to create space for reward visualization
1354
+ img = jax.image.resize(
1355
+ img,
1356
+ (self.size[1] * 3, self.size[0] * 3, 3),
1357
+ jax.image.ResizeMethod.NEAREST,
1358
+ )
1359
+
1360
+ # Compute rewards for all cells
1361
+ reward_grid = self._compute_reward_grid(state)
1362
+
1363
+ # Convert rewards to colors
1364
+ reward_colors = self._reward_to_color(reward_grid)
1365
+
1366
+ # Resize reward colors to match 3x scale and place in middle cells
1367
+ # We need to place reward colors at positions (i*3+1, j*3+1) for each (i,j)
1368
+ # Create index arrays for middle cells
1369
+ i_indices = jnp.arange(self.size[1])[:, None] * 3 + 1
1370
+ j_indices = jnp.arange(self.size[0])[None, :] * 3 + 1
1371
+
1372
+ # Broadcast and set middle cells
1373
+ img = img.at[i_indices, j_indices].set(reward_colors)
1374
+
1268
1375
  # Tint the agent's aperture
1269
1376
  y_coords, x_coords, y_coords_adj, x_coords_adj = (
1270
1377
  self._compute_aperture_coordinates(state.pos)
@@ -1273,27 +1380,127 @@ class ForagaxEnv(environment.Environment):
1273
1380
  alpha = 0.2
1274
1381
  agent_color = jnp.array(AGENT.color)
1275
1382
 
1276
- if self.nowrap:
1277
- # Create tint mask: any in-bounds original position maps to a cell makes it tinted
1278
- tint_mask = jnp.zeros((self.size[1], self.size[0]), dtype=int)
1279
- tint_mask = tint_mask.at[y_coords_adj, x_coords_adj].set(1)
1280
- # Apply tint to masked positions
1281
- original_colors = img
1282
- tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
1283
- img = jnp.where(tint_mask[..., None], tinted_colors, img)
1383
+ if is_reward_mode:
1384
+ # For reward mode, we need to adjust coordinates for 3x scaled image
1385
+ if self.nowrap:
1386
+ # Create tint mask for 3x scaled image
1387
+ tint_mask = jnp.zeros(
1388
+ (self.size[1] * 3, self.size[0] * 3), dtype=bool
1389
+ )
1390
+
1391
+ # For each aperture cell, tint all 9 cells in its 3x3 block
1392
+ # Create meshgrid to get all aperture cell coordinates
1393
+ y_grid, x_grid = jnp.meshgrid(
1394
+ y_coords_adj.flatten(), x_coords_adj.flatten(), indexing="ij"
1395
+ )
1396
+ y_flat = y_grid.flatten()
1397
+ x_flat = x_grid.flatten()
1398
+
1399
+ # Create offset arrays for 3x3 blocks
1400
+ offsets = jnp.array(
1401
+ [[dy, dx] for dy in range(3) for dx in range(3)]
1402
+ )
1403
+
1404
+ # For each aperture cell, expand to 9 cells
1405
+ # We need to repeat each cell coordinate 9 times, then add offsets
1406
+ num_aperture_cells = y_flat.size
1407
+ y_base = jnp.repeat(
1408
+ y_flat * 3, 9
1409
+ ) # Repeat each y coord 9 times and scale by 3
1410
+ x_base = jnp.repeat(
1411
+ x_flat * 3, 9
1412
+ ) # Repeat each x coord 9 times and scale by 3
1413
+ y_offsets = jnp.tile(
1414
+ offsets[:, 0], num_aperture_cells
1415
+ ) # Tile all 9 offsets
1416
+ x_offsets = jnp.tile(
1417
+ offsets[:, 1], num_aperture_cells
1418
+ ) # Tile all 9 offsets
1419
+ y_expanded = y_base + y_offsets
1420
+ x_expanded = x_base + x_offsets
1421
+
1422
+ tint_mask = tint_mask.at[y_expanded, x_expanded].set(True)
1423
+
1424
+ original_colors = img
1425
+ tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
1426
+ img = jnp.where(tint_mask[..., None], tinted_colors, img)
1427
+ else:
1428
+ # Tint all 9 cells in each 3x3 block for aperture cells
1429
+ # Create meshgrid to get all aperture cell coordinates
1430
+ y_grid, x_grid = jnp.meshgrid(
1431
+ y_coords_adj.flatten(), x_coords_adj.flatten(), indexing="ij"
1432
+ )
1433
+ y_flat = y_grid.flatten()
1434
+ x_flat = x_grid.flatten()
1435
+
1436
+ # Create offset arrays for 3x3 blocks
1437
+ offsets = jnp.array(
1438
+ [[dy, dx] for dy in range(3) for dx in range(3)]
1439
+ )
1440
+
1441
+ # For each aperture cell, expand to 9 cells
1442
+ # We need to repeat each cell coordinate 9 times, then add offsets
1443
+ num_aperture_cells = y_flat.size
1444
+ y_base = jnp.repeat(
1445
+ y_flat * 3, 9
1446
+ ) # Repeat each y coord 9 times and scale by 3
1447
+ x_base = jnp.repeat(
1448
+ x_flat * 3, 9
1449
+ ) # Repeat each x coord 9 times and scale by 3
1450
+ y_offsets = jnp.tile(
1451
+ offsets[:, 0], num_aperture_cells
1452
+ ) # Tile all 9 offsets
1453
+ x_offsets = jnp.tile(
1454
+ offsets[:, 1], num_aperture_cells
1455
+ ) # Tile all 9 offsets
1456
+ y_expanded = y_base + y_offsets
1457
+ x_expanded = x_base + x_offsets
1458
+
1459
+ # Get original colors and tint them
1460
+ original_colors = img[y_expanded, x_expanded]
1461
+ tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
1462
+ img = img.at[y_expanded, x_expanded].set(tinted_colors)
1463
+
1464
+ # Agent color - set all 9 cells of the agent's 3x3 block
1465
+ agent_y, agent_x = state.pos[1], state.pos[0]
1466
+ agent_offsets = jnp.array(
1467
+ [[dy, dx] for dy in range(3) for dx in range(3)]
1468
+ )
1469
+ agent_y_cells = agent_y * 3 + agent_offsets[:, 0]
1470
+ agent_x_cells = agent_x * 3 + agent_offsets[:, 1]
1471
+ img = img.at[agent_y_cells, agent_x_cells].set(
1472
+ jnp.array(AGENT.color, dtype=jnp.uint8)
1473
+ )
1474
+
1475
+ # Scale by 8 to final size
1476
+ img = jax.image.resize(
1477
+ img,
1478
+ (self.size[1] * 24, self.size[0] * 24, 3),
1479
+ jax.image.ResizeMethod.NEAREST,
1480
+ )
1284
1481
  else:
1285
- original_colors = img[y_coords_adj, x_coords_adj]
1286
- tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
1287
- img = img.at[y_coords_adj, x_coords_adj].set(tinted_colors)
1482
+ # Standard rendering without reward visualization
1483
+ if self.nowrap:
1484
+ # Create tint mask: any in-bounds original position maps to a cell makes it tinted
1485
+ tint_mask = jnp.zeros((self.size[1], self.size[0]), dtype=int)
1486
+ tint_mask = tint_mask.at[y_coords_adj, x_coords_adj].set(1)
1487
+ # Apply tint to masked positions
1488
+ original_colors = img
1489
+ tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
1490
+ img = jnp.where(tint_mask[..., None], tinted_colors, img)
1491
+ else:
1492
+ original_colors = img[y_coords_adj, x_coords_adj]
1493
+ tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
1494
+ img = img.at[y_coords_adj, x_coords_adj].set(tinted_colors)
1288
1495
 
1289
- # Agent color
1290
- img = img.at[state.pos[1], state.pos[0]].set(jnp.array(AGENT.color))
1496
+ # Agent color
1497
+ img = img.at[state.pos[1], state.pos[0]].set(jnp.array(AGENT.color))
1291
1498
 
1292
- img = jax.image.resize(
1293
- img,
1294
- (self.size[1] * 24, self.size[0] * 24, 3),
1295
- jax.image.ResizeMethod.NEAREST,
1296
- )
1499
+ img = jax.image.resize(
1500
+ img,
1501
+ (self.size[1] * 24, self.size[0] * 24, 3),
1502
+ jax.image.ResizeMethod.NEAREST,
1503
+ )
1297
1504
 
1298
1505
  if is_true_mode:
1299
1506
  # Apply true object borders by overlaying true colors on border pixels
@@ -1340,16 +1547,69 @@ class ForagaxEnv(environment.Environment):
1340
1547
  aperture_one_hot = jax.nn.one_hot(aperture, len(self.object_ids))
1341
1548
  img = jnp.tensordot(aperture_one_hot, self.object_colors, axes=1)
1342
1549
 
1343
- # Draw agent in the center
1344
- center_y, center_x = self.aperture_size[1] // 2, self.aperture_size[0] // 2
1345
- img = img.at[center_y, center_x].set(jnp.array(AGENT.color))
1550
+ if is_reward_mode:
1551
+ # Scale image by 3 to create space for reward visualization
1552
+ img = img.astype(jnp.uint8)
1553
+ img = jax.image.resize(
1554
+ img,
1555
+ (self.aperture_size[0] * 3, self.aperture_size[1] * 3, 3),
1556
+ jax.image.ResizeMethod.NEAREST,
1557
+ )
1346
1558
 
1347
- img = img.astype(jnp.uint8)
1348
- img = jax.image.resize(
1349
- img,
1350
- (self.aperture_size[0] * 24, self.aperture_size[1] * 24, 3),
1351
- jax.image.ResizeMethod.NEAREST,
1352
- )
1559
+ # Compute rewards for aperture region
1560
+ y_coords, x_coords, y_coords_adj, x_coords_adj = (
1561
+ self._compute_aperture_coordinates(state.pos)
1562
+ )
1563
+
1564
+ # Get reward grid for the full world
1565
+ full_reward_grid = self._compute_reward_grid(state)
1566
+
1567
+ # Extract aperture rewards
1568
+ aperture_rewards = full_reward_grid[y_coords_adj, x_coords_adj]
1569
+
1570
+ # Convert rewards to colors
1571
+ reward_colors = self._reward_to_color(aperture_rewards)
1572
+
1573
+ # Place reward colors in the middle cells (index 1 in each 3x3 block)
1574
+ i_indices = jnp.arange(self.aperture_size[0])[:, None] * 3 + 1
1575
+ j_indices = jnp.arange(self.aperture_size[1])[None, :] * 3 + 1
1576
+ img = img.at[i_indices, j_indices].set(reward_colors)
1577
+
1578
+ # Draw agent in the center (all 9 cells of the 3x3 block)
1579
+ center_y, center_x = (
1580
+ self.aperture_size[1] // 2,
1581
+ self.aperture_size[0] // 2,
1582
+ )
1583
+ agent_offsets = jnp.array(
1584
+ [[dy, dx] for dy in range(3) for dx in range(3)]
1585
+ )
1586
+ agent_y_cells = center_y * 3 + agent_offsets[:, 0]
1587
+ agent_x_cells = center_x * 3 + agent_offsets[:, 1]
1588
+ img = img.at[agent_y_cells, agent_x_cells].set(
1589
+ jnp.array(AGENT.color, dtype=jnp.uint8)
1590
+ )
1591
+
1592
+ # Scale by 8 to final size
1593
+ img = jax.image.resize(
1594
+ img,
1595
+ (self.aperture_size[0] * 24, self.aperture_size[1] * 24, 3),
1596
+ jax.image.ResizeMethod.NEAREST,
1597
+ )
1598
+ else:
1599
+ # Standard rendering without reward visualization
1600
+ # Draw agent in the center
1601
+ center_y, center_x = (
1602
+ self.aperture_size[1] // 2,
1603
+ self.aperture_size[0] // 2,
1604
+ )
1605
+ img = img.at[center_y, center_x].set(jnp.array(AGENT.color))
1606
+
1607
+ img = img.astype(jnp.uint8)
1608
+ img = jax.image.resize(
1609
+ img,
1610
+ (self.aperture_size[0] * 24, self.aperture_size[1] * 24, 3),
1611
+ jax.image.ResizeMethod.NEAREST,
1612
+ )
1353
1613
 
1354
1614
  if is_true_mode:
1355
1615
  # Apply true object borders by overlaying true colors on border pixels
foragax/objects.py CHANGED
@@ -240,6 +240,7 @@ class FourierObject(BaseForagaxObject):
240
240
  color: Tuple[int, int, int] = (0, 0, 0),
241
241
  reward_delay: int = 0,
242
242
  max_reward_delay: Optional[int] = None,
243
+ regen_delay: Optional[Tuple[int, int]] = None,
243
244
  ):
244
245
  if max_reward_delay is None:
245
246
  max_reward_delay = reward_delay
@@ -248,13 +249,14 @@ class FourierObject(BaseForagaxObject):
248
249
  blocking=False,
249
250
  collectable=True,
250
251
  color=color,
251
- random_respawn=False, # Objects don't respawn individually
252
+ random_respawn=True,
252
253
  max_reward_delay=max_reward_delay,
253
254
  expiry_time=None,
254
255
  )
255
256
  self.num_fourier_terms = num_fourier_terms
256
257
  self.base_magnitude = base_magnitude
257
258
  self.reward_delay_val = reward_delay
259
+ self.regen_delay_range = regen_delay
258
260
 
259
261
  def get_state(self, key: jax.Array) -> jax.Array:
260
262
  """Generate random Fourier series parameters.
@@ -353,6 +355,9 @@ class FourierObject(BaseForagaxObject):
353
355
 
354
356
  def regen_delay(self, clock: int, rng: jax.Array) -> int:
355
357
  """No individual regeneration - returns infinity."""
358
+ if self.regen_delay_range is not None:
359
+ min_delay, max_delay = self.regen_delay_range
360
+ return jax.random.randint(rng, (), min_delay, max_delay)
356
361
  return jnp.iinfo(jnp.int32).max
357
362
 
358
363
  def expiry_regen_delay(self, clock: int, rng: jax.Array) -> int:
@@ -709,6 +714,7 @@ def create_fourier_objects(
709
714
  num_fourier_terms: int = 10,
710
715
  base_magnitude: float = 1.0,
711
716
  reward_delay: int = 0,
717
+ regen_delay: Optional[Tuple[int, int]] = None,
712
718
  ):
713
719
  """Create HOT and COLD FourierObject instances.
714
720
 
@@ -726,6 +732,7 @@ def create_fourier_objects(
726
732
  base_magnitude=base_magnitude,
727
733
  color=(0, 0, 0),
728
734
  reward_delay=reward_delay,
735
+ regen_delay=regen_delay,
729
736
  )
730
737
 
731
738
  cold = FourierObject(
@@ -734,6 +741,7 @@ def create_fourier_objects(
734
741
  base_magnitude=base_magnitude,
735
742
  color=(0, 0, 0),
736
743
  reward_delay=reward_delay,
744
+ regen_delay=regen_delay,
737
745
  )
738
746
 
739
747
  return hot, cold
foragax/registry.py CHANGED
@@ -104,6 +104,21 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
104
104
  "dynamic_biomes": True,
105
105
  "biome_consumption_threshold": 0.9,
106
106
  },
107
+ "ForagaxDiwali-v2": {
108
+ "size": (15, 15),
109
+ "aperture_size": None,
110
+ "objects": None,
111
+ "biomes": (
112
+ # Hot biome
113
+ Biome(start=(0, 2), stop=(15, 6), object_frequencies=(0.5, 0.0)),
114
+ # Cold biome
115
+ Biome(start=(0, 9), stop=(15, 13), object_frequencies=(0.0, 0.5)),
116
+ ),
117
+ "nowrap": False,
118
+ "deterministic_spawn": True,
119
+ "dynamic_biomes": True,
120
+ "biome_consumption_threshold": 200,
121
+ },
107
122
  "ForagaxTwoBiome-v1": {
108
123
  "size": (15, 15),
109
124
  "aperture_size": None,
@@ -539,6 +554,12 @@ def make(
539
554
  num_fourier_terms=10,
540
555
  reward_delay=reward_delay,
541
556
  )
557
+ if env_id == "ForagaxDiwali-v2":
558
+ config["objects"] = create_fourier_objects(
559
+ num_fourier_terms=10,
560
+ reward_delay=reward_delay,
561
+ regen_delay=(9, 11),
562
+ )
542
563
 
543
564
  if env_id == "ForagaxSineTwoBiome-v1":
544
565
  biome1_oyster, biome1_deathcap, biome2_oyster, biome2_deathcap = (