continual-foragax 0.27.0__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: continual-foragax
3
- Version: 0.27.0
3
+ Version: 0.28.0
4
4
  Summary: A continual reinforcement learning benchmark
5
5
  Author-email: Steven Tang <stang5@ualberta.ca>
6
6
  Requires-Python: >=3.8
@@ -1,8 +1,8 @@
1
1
  foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
3
- foragax/env.py,sha256=2Z2-3wjL_gHCqehMFsZA761VlUWQvsTU6PawoTtzcCY,25103
3
+ foragax/env.py,sha256=lg6nRT2goO84FNGdL0CbUTGj1iBiQXUYOWf2qavMMvE,25658
4
4
  foragax/objects.py,sha256=FCLZ-8d7qq9VMTG6G-TaRt842-sjgB0-DH0IoHwwngI,9503
5
- foragax/registry.py,sha256=7-6VDN1MKVEvX_1u5G8NkSpv9BccEmtjJa77-OTNg3A,14324
5
+ foragax/registry.py,sha256=HysNaZs1tcbAcr53l8Cb2NeZ-_FmE6OpUe_zIks-ObM,15089
6
6
  foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
7
7
  foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
8
8
  foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
128
128
  foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
129
129
  foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
130
130
  foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
131
- continual_foragax-0.27.0.dist-info/METADATA,sha256=Axd__jkh5OIh1T9CmzB9IarAguoWNvsECjllZ20mG3w,4897
132
- continual_foragax-0.27.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
- continual_foragax-0.27.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
- continual_foragax-0.27.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
- continual_foragax-0.27.0.dist-info/RECORD,,
131
+ continual_foragax-0.28.0.dist-info/METADATA,sha256=ijxVOxZXSpbQ3ORIHKrUxgciSiuBTqxpA-nE2TWvitQ,4897
132
+ continual_foragax-0.28.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
+ continual_foragax-0.28.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
+ continual_foragax-0.28.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
+ continual_foragax-0.28.0.dist-info/RECORD,,
foragax/env.py CHANGED
@@ -76,6 +76,7 @@ class ForagaxEnv(environment.Environment):
76
76
  nowrap: bool = False,
77
77
  deterministic_spawn: bool = False,
78
78
  teleport_interval: Optional[int] = None,
79
+ observation_type: str = "object",
79
80
  ):
80
81
  super().__init__()
81
82
  self._name = name
@@ -83,9 +84,17 @@ class ForagaxEnv(environment.Environment):
83
84
  size = (size, size)
84
85
  self.size = size
85
86
 
86
- if isinstance(aperture_size, int):
87
- aperture_size = (aperture_size, aperture_size)
88
- self.aperture_size = aperture_size
87
+ # Handle aperture_size = -1 for world view
88
+ if isinstance(aperture_size, int) and aperture_size == -1:
89
+ self.full_world = True
90
+ self.aperture_size = self.size # Use full size for consistency
91
+ else:
92
+ self.full_world = False
93
+ if isinstance(aperture_size, int):
94
+ aperture_size = (aperture_size, aperture_size)
95
+ self.aperture_size = aperture_size
96
+
97
+ self.observation_type = observation_type
89
98
  self.nowrap = nowrap
90
99
  self.deterministic_spawn = deterministic_spawn
91
100
  self.teleport_interval = teleport_interval
@@ -152,6 +161,29 @@ class ForagaxEnv(environment.Environment):
152
161
  )
153
162
  self.biome_masks.append(mask)
154
163
 
164
+ # Compute unique colors and mapping for partial observability (for 'color' observation_type)
165
+ # Exclude EMPTY (index 0) from color channels
166
+ object_colors_no_empty = self.object_colors[1:]
167
+
168
+ # Find unique colors in order of first appearance
169
+ unique_colors = []
170
+ color_indices = jnp.zeros(len(object_colors_no_empty), dtype=int)
171
+ color_map = {}
172
+ next_channel = 0
173
+
174
+ for i, color in enumerate(object_colors_no_empty):
175
+ color_tuple = tuple(color.tolist())
176
+ if color_tuple not in color_map:
177
+ color_map[color_tuple] = next_channel
178
+ unique_colors.append(color)
179
+ next_channel += 1
180
+ color_indices = color_indices.at[i].set(color_map[color_tuple])
181
+
182
+ self.unique_colors = jnp.array(unique_colors)
183
+ self.num_color_channels = len(unique_colors)
184
+ # color_indices maps from object_id-1 to color_channel_index
185
+ self.object_to_color_map = color_indices
186
+
155
187
  @property
156
188
  def default_params(self) -> EnvParams:
157
189
  return EnvParams(
@@ -412,6 +444,102 @@ class ForagaxEnv(environment.Environment):
412
444
 
413
445
  return aperture
414
446
 
447
+ def get_obs(self, state: EnvState, params: EnvParams, key=None) -> jax.Array:
448
+ """Get observation based on observation_type and full_world."""
449
+ # Decode grid for observation
450
+ obs_grid = jnp.maximum(0, state.object_grid)
451
+
452
+ if self.full_world:
453
+ return self._get_world_obs(obs_grid, state)
454
+ else:
455
+ grid = self._get_aperture(obs_grid, state.pos)
456
+ return self._get_aperture_obs(grid, state)
457
+
458
+ def _get_world_obs(self, obs_grid: jax.Array, state: EnvState) -> jax.Array:
459
+ """Get world observation."""
460
+ num_obj_types = len(self.object_ids)
461
+ if self.observation_type == "object":
462
+ obs = jnp.zeros(
463
+ (self.size[1], self.size[0], num_obj_types), dtype=jnp.float32
464
+ )
465
+ num_object_channels = num_obj_types - 1
466
+ # create masks for all objects at once
467
+ object_ids = jnp.arange(1, num_obj_types)
468
+ object_masks = obs_grid[..., None] == object_ids[None, None, :]
469
+ obs = obs.at[:, :, :num_object_channels].set(object_masks.astype(float))
470
+ obs = obs.at[state.pos[1], state.pos[0], -1].set(1)
471
+ return obs
472
+ elif self.observation_type == "rgb":
473
+ obs = jax.nn.one_hot(obs_grid, num_obj_types)
474
+ # Agent position
475
+ obs = obs.at[state.pos[1], state.pos[0], :].set(0)
476
+ obs = obs.at[state.pos[1], state.pos[0], -1].set(1)
477
+ colors = self.object_colors / 255.0
478
+ obs = jnp.tensordot(obs, colors, axes=1)
479
+ return obs
480
+ elif self.observation_type == "color":
481
+ # Handle case with no objects (only EMPTY)
482
+ if self.num_color_channels == 0:
483
+ return jnp.zeros(obs_grid.shape + (0,), dtype=jnp.float32)
484
+ # Map object IDs to color channel indices
485
+ color_channels = jnp.where(
486
+ obs_grid == 0,
487
+ -1,
488
+ jnp.take(self.object_to_color_map, obs_grid - 1, axis=0),
489
+ )
490
+ obs = jax.nn.one_hot(color_channels, self.num_color_channels, axis=-1)
491
+ return obs
492
+ else:
493
+ raise ValueError(f"Unknown observation_type: {self.observation_type}")
494
+
495
+ def _get_aperture_obs(self, aperture: jax.Array, state: EnvState) -> jax.Array:
496
+ """Get aperture observation."""
497
+ if self.observation_type == "object":
498
+ num_obj_types = len(self.object_ids)
499
+ obs = jax.nn.one_hot(aperture, num_obj_types, axis=-1)
500
+ return obs
501
+ elif self.observation_type == "rgb":
502
+ num_obj_types = len(self.object_ids)
503
+ aperture_one_hot = jax.nn.one_hot(aperture, num_obj_types)
504
+ colors = self.object_colors / 255.0
505
+ obs = jnp.tensordot(aperture_one_hot, colors, axes=1)
506
+ return obs
507
+ elif self.observation_type == "color":
508
+ # Handle case with no objects (only EMPTY)
509
+ if self.num_color_channels == 0:
510
+ return jnp.zeros(aperture.shape + (0,), dtype=jnp.float32)
511
+ # Map object IDs to color channel indices
512
+ color_channels = jnp.where(
513
+ aperture == 0,
514
+ -1, # Special value for EMPTY
515
+ jnp.take(self.object_to_color_map, aperture - 1, axis=0),
516
+ )
517
+ obs = jax.nn.one_hot(color_channels, self.num_color_channels, axis=-1)
518
+ return obs
519
+ else:
520
+ raise ValueError(f"Unknown observation_type: {self.observation_type}")
521
+
522
+ def observation_space(self, params: EnvParams) -> spaces.Box:
523
+ """Observation space based on observation_type and full_world."""
524
+ if self.full_world:
525
+ size = tuple(reversed(self.size))
526
+ else:
527
+ size = self.aperture_size
528
+
529
+ if self.observation_type == "rgb":
530
+ channels = 3
531
+ elif self.observation_type == "object":
532
+ num_obj_types = len(self.objects)
533
+ channels = num_obj_types
534
+ elif self.observation_type == "color":
535
+ channels = self.num_color_channels
536
+ else:
537
+ raise ValueError(f"Unknown observation_type: {self.observation_type}")
538
+
539
+ obs_shape = (*size, channels)
540
+
541
+ return spaces.Box(0, 1, obs_shape, float)
542
+
415
543
  @partial(jax.jit, static_argnames=("self", "render_mode"))
416
544
  def render(self, state: EnvState, params: EnvParams, render_mode: str = "world"):
417
545
  """Render the environment state."""
@@ -520,133 +648,3 @@ class ForagaxEnv(environment.Environment):
520
648
  raise ValueError(f"Unknown render_mode: {render_mode}")
521
649
 
522
650
  return img
523
-
524
-
525
- class ForagaxObjectEnv(ForagaxEnv):
526
- """Foragax environment with object-based aperture observation."""
527
-
528
- def __init__(
529
- self,
530
- name: str = "Foragax-v0",
531
- size: Union[Tuple[int, int], int] = (10, 10),
532
- aperture_size: Union[Tuple[int, int], int] = (5, 5),
533
- objects: Tuple[BaseForagaxObject, ...] = (),
534
- biomes: Tuple[Biome, ...] = (Biome(object_frequencies=()),),
535
- nowrap: bool = False,
536
- deterministic_spawn: bool = False,
537
- teleport_interval: Optional[int] = None,
538
- ):
539
- super().__init__(
540
- name,
541
- size,
542
- aperture_size,
543
- objects,
544
- biomes,
545
- nowrap,
546
- deterministic_spawn,
547
- teleport_interval,
548
- )
549
-
550
- # Compute unique colors and mapping for partial observability
551
- # Exclude EMPTY (index 0) from color channels
552
- object_colors_no_empty = self.object_colors[1:]
553
-
554
- # Find unique colors in order of first appearance
555
- unique_colors = []
556
- color_indices = jnp.zeros(len(object_colors_no_empty), dtype=int)
557
- color_map = {}
558
- next_channel = 0
559
-
560
- for i, color in enumerate(object_colors_no_empty):
561
- color_tuple = tuple(color.tolist())
562
- if color_tuple not in color_map:
563
- color_map[color_tuple] = next_channel
564
- unique_colors.append(color)
565
- next_channel += 1
566
- color_indices = color_indices.at[i].set(color_map[color_tuple])
567
-
568
- self.unique_colors = jnp.array(unique_colors)
569
- self.num_color_channels = len(unique_colors)
570
- # color_indices maps from object_id-1 to color_channel_index
571
- self.object_to_color_map = color_indices
572
-
573
- def get_obs(self, state: EnvState, params: EnvParams, key=None) -> jax.Array:
574
- # Decode grid for observation
575
- obs_grid = jnp.maximum(0, state.object_grid)
576
- aperture = self._get_aperture(obs_grid, state.pos)
577
-
578
- # Handle case with no objects (only EMPTY)
579
- if self.num_color_channels == 0:
580
- return jnp.zeros(aperture.shape + (0,), dtype=jnp.float32)
581
-
582
- # Map object IDs to color channel indices
583
- # aperture contains object IDs (0 = EMPTY, 1+ = objects)
584
- # For EMPTY (0), we want no color channel activated
585
- # For objects (1+), map to color channel using object_to_color_map
586
- color_channels = jnp.where(
587
- aperture == 0,
588
- -1, # Special value for EMPTY
589
- jnp.take(self.object_to_color_map, aperture - 1, axis=0),
590
- )
591
-
592
- # Create one-hot encoding for color channels
593
- # jax.nn.one_hot produces all zeros for -1 (EMPTY positions)
594
- obs = jax.nn.one_hot(color_channels, self.num_color_channels, axis=-1)
595
-
596
- return obs
597
-
598
- def observation_space(self, params: EnvParams) -> spaces.Box:
599
- obs_shape = (
600
- self.aperture_size[0],
601
- self.aperture_size[1],
602
- self.num_color_channels,
603
- )
604
- return spaces.Box(0, 1, obs_shape, float)
605
-
606
-
607
- class ForagaxRGBEnv(ForagaxEnv):
608
- """Foragax environment with color-based aperture observation."""
609
-
610
- def get_obs(self, state: EnvState, params: EnvParams, key=None) -> jax.Array:
611
- num_obj_types = len(self.object_ids)
612
- # Decode grid for observation
613
- obs_grid = jnp.maximum(0, state.object_grid)
614
- aperture = self._get_aperture(obs_grid, state.pos)
615
- aperture_one_hot = jax.nn.one_hot(aperture, num_obj_types)
616
-
617
- # Agent position is always at the center of the aperture
618
- center = (self.aperture_size[1] // 2, self.aperture_size[0] // 2)
619
- aperture_one_hot = aperture_one_hot.at[center[0], center[1], :].set(0)
620
- aperture_one_hot = aperture_one_hot.at[center[0], center[1], -1].set(1)
621
-
622
- colors = self.object_colors / 255.0
623
- obs = jnp.tensordot(aperture_one_hot, colors, axes=1)
624
- return obs
625
-
626
- def observation_space(self, params: EnvParams) -> spaces.Box:
627
- obs_shape = (self.aperture_size[0], self.aperture_size[1], 3)
628
- return spaces.Box(0, 1, obs_shape, float)
629
-
630
-
631
- class ForagaxWorldEnv(ForagaxEnv):
632
- """Foragax environment with world observation."""
633
-
634
- def get_obs(self, state: EnvState, params: EnvParams, key=None) -> jax.Array:
635
- num_obj_types = len(self.object_ids)
636
- # Decode grid for observation
637
- obs_grid = jnp.maximum(0, state.object_grid)
638
- obs = jnp.zeros((self.size[1], self.size[0], num_obj_types), dtype=jnp.float32)
639
-
640
- num_object_channels = num_obj_types - 1
641
- # create masks for all objects at once
642
- object_ids = jnp.arange(1, num_obj_types)
643
- object_masks = obs_grid[..., None] == object_ids[None, None, :]
644
- obs = obs.at[:, :, :num_object_channels].set(object_masks.astype(float))
645
-
646
- obs = obs.at[state.pos[1], state.pos[0], -1].set(1)
647
- return obs
648
-
649
- def observation_space(self, params: EnvParams) -> spaces.Box:
650
- num_obj_types = len(self.object_ids)
651
- obs_shape = (self.size[1], self.size[0], num_obj_types)
652
- return spaces.Box(0, 1, obs_shape, float)
foragax/registry.py CHANGED
@@ -1,13 +1,11 @@
1
1
  """Factory functions for creating Foragax environment variants."""
2
2
 
3
+ import warnings
3
4
  from typing import Any, Dict, Optional, Tuple
4
5
 
5
6
  from foragax.env import (
6
7
  Biome,
7
8
  ForagaxEnv,
8
- ForagaxObjectEnv,
9
- ForagaxRGBEnv,
10
- ForagaxWorldEnv,
11
9
  )
12
10
  from foragax.objects import (
13
11
  BROWN_MOREL,
@@ -347,21 +345,23 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
347
345
 
348
346
  def make(
349
347
  env_id: str,
350
- observation_type: str = "object",
348
+ observation_type: str = "color",
351
349
  aperture_size: Optional[Tuple[int, int]] = (5, 5),
352
350
  file_index: int = 0,
353
351
  nowrap: Optional[bool] = None,
352
+ **kwargs: Any,
354
353
  ) -> ForagaxEnv:
355
354
  """Create a Foragax environment.
356
355
 
357
356
  Args:
358
357
  env_id: The ID of the environment to create.
359
- observation_type: The type of observation to use. One of "object", "rgb", or "world".
360
- aperture_size: The size of the agent's observation aperture. If None, the default
361
- for the environment is used.
358
+ observation_type: The type of observation to use. One of "object", "rgb", or "color".
359
+ aperture_size: The size of the agent's observation aperture. If -1, full world observation.
360
+ If None, the default for the environment is used.
362
361
  file_index: File index for weather objects. nowrap: If True, disables
363
362
  wrapping around environment boundaries. If None, uses defaults per
364
363
  environment.
364
+ **kwargs: Additional keyword arguments to pass to the ForagaxEnv constructor.
365
365
 
366
366
  Returns:
367
367
  A Foragax environment instance.
@@ -371,11 +371,15 @@ def make(
371
371
 
372
372
  config = ENV_CONFIGS[env_id].copy()
373
373
  if isinstance(aperture_size, int):
374
- aperture_size = (aperture_size, aperture_size)
374
+ if aperture_size == -1:
375
+ aperture_size = -1 # keep as -1
376
+ else:
377
+ aperture_size = (aperture_size, aperture_size)
375
378
  config["aperture_size"] = aperture_size
376
379
  if nowrap is not None:
377
380
  config["nowrap"] = nowrap
378
381
 
382
+ # Handle special size and biome configurations
379
383
  if env_id in (
380
384
  "ForagaxTwoBiome-v7",
381
385
  "ForagaxTwoBiome-v8",
@@ -384,7 +388,10 @@ def make(
384
388
  "ForagaxTwoBiome-v15",
385
389
  "ForagaxTwoBiome-v16",
386
390
  ):
387
- margin = aperture_size[1] // 2 + 1
391
+ if aperture_size == -1:
392
+ margin = 0 # for world view, no margin needed
393
+ else:
394
+ margin = aperture_size[1] // 2 + 1
388
395
  width = 2 * margin + 9
389
396
  config["size"] = (width, 15)
390
397
  config["biomes"] = (
@@ -403,7 +410,10 @@ def make(
403
410
  )
404
411
 
405
412
  if env_id in ("ForagaxTwoBiome-v11", "ForagaxTwoBiome-v12"):
406
- margin = aperture_size[1] // 2 + 1
413
+ if aperture_size == -1:
414
+ margin = 0
415
+ else:
416
+ margin = aperture_size[1] // 2 + 1
407
417
  width = 2 * margin + 9
408
418
  config["size"] = (width, 15)
409
419
  config["biomes"] = (
@@ -422,7 +432,10 @@ def make(
422
432
  )
423
433
 
424
434
  if env_id in ("ForagaxWeather-v3", "ForagaxWeather-v4"):
425
- margin = aperture_size[1] // 2 + 1
435
+ if aperture_size == -1:
436
+ margin = 0
437
+ else:
438
+ margin = aperture_size[1] // 2 + 1
426
439
  width = 2 * margin + 9
427
440
  config["size"] = (15, width)
428
441
  config["biomes"] = (
@@ -456,16 +469,20 @@ def make(
456
469
  if env_id == "ForagaxTwoBiome-v16":
457
470
  config["teleport_interval"] = 10000
458
471
 
459
- env_class_map = {
460
- "object": ForagaxObjectEnv,
461
- "rgb": ForagaxRGBEnv,
462
- "world": ForagaxWorldEnv,
463
- }
472
+ # Backward compatibility: map "world" to "object" with full world
473
+ if observation_type == "world":
474
+ # add deprecation warning
475
+ warnings.warn(
476
+ "'world' observation_type is deprecated. Use 'object' with aperture_size=-1 instead.",
477
+ DeprecationWarning,
478
+ )
479
+ observation_type = "object"
480
+ config["aperture_size"] = -1
464
481
 
465
- if observation_type not in env_class_map:
466
- raise ValueError(f"Unknown observation type: {observation_type}")
482
+ if observation_type not in ("object", "rgb", "color"):
483
+ raise ValueError(f"Unknown observation_type: {observation_type}")
467
484
 
468
- env_class = env_class_map[observation_type]
469
485
  config["name"] = env_id
486
+ config["observation_type"] = observation_type
470
487
 
471
- return env_class(**config)
488
+ return ForagaxEnv(**config, **kwargs)