continual-foragax 0.8.2__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.8.2.dist-info → continual_foragax-0.10.0.dist-info}/METADATA +10 -1
- {continual_foragax-0.8.2.dist-info → continual_foragax-0.10.0.dist-info}/RECORD +9 -7
- foragax/colors.py +74 -0
- foragax/env.py +71 -10
- foragax/objects.py +1 -1
- foragax/rendering.py +74 -0
- {continual_foragax-0.8.2.dist-info → continual_foragax-0.10.0.dist-info}/WHEEL +0 -0
- {continual_foragax-0.8.2.dist-info → continual_foragax-0.10.0.dist-info}/entry_points.txt +0 -0
- {continual_foragax-0.8.2.dist-info → continual_foragax-0.10.0.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: continual-foragax
|
3
|
-
Version: 0.8.2
|
3
|
+
Version: 0.10.0
|
4
4
|
Summary: A continual reinforcement learning benchmark
|
5
5
|
Author-email: Steven Tang <stang5@ualberta.ca>
|
6
6
|
Requires-Python: >=3.8
|
@@ -119,3 +119,12 @@ class into registry configs or construct environments programmatically.
|
|
119
119
|
## Development
|
120
120
|
|
121
121
|
Run unit tests via pytest.
|
122
|
+
|
123
|
+
## Acknowledgments
|
124
|
+
|
125
|
+
We acknowledge the data providers in the ECA&D project. Klein Tank, A.M.G. and
|
126
|
+
Coauthors, 2002. Daily dataset of 20th-century surface air temperature and
|
127
|
+
precipitation series for the European Climate Assessment. Int. J. of Climatol.,
|
128
|
+
22, 1441-1453.
|
129
|
+
|
130
|
+
Data and metadata available at https://www.ecad.eu
|
@@ -1,7 +1,9 @@
|
|
1
1
|
foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
-
foragax/
|
3
|
-
foragax/
|
2
|
+
foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
|
3
|
+
foragax/env.py,sha256=OtpcyqzBOQLdTvvRegD3SYm4mi4Ga2WE5eJ7OQmQOaw,18294
|
4
|
+
foragax/objects.py,sha256=CyBxrykTxpHCI_2hE9jE8mG4TU8R7VxzKdQ5mtxkEqU,6004
|
4
5
|
foragax/registry.py,sha256=7_RDXvm_3RNO7culBLGkE0jH8Wk_q6jbMv72dZx4JO8,2722
|
6
|
+
foragax/rendering.py,sha256=KAoQpdndy5JDQlyG0c5QDHuH-_Tfy5RuVlDtndnHVjc,2765
|
5
7
|
foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
|
6
8
|
foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
|
7
9
|
foragax/data/ECA_non-blended_custom/TG_SOUID100928.txt,sha256=AaJMWisVu2YPlZFwvetzHZBI2DqWqUYmp8BVzx6gWZI,817991
|
@@ -126,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
|
|
126
128
|
foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
|
127
129
|
foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
|
128
130
|
foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
|
129
|
-
continual_foragax-0.
|
130
|
-
continual_foragax-0.
|
131
|
-
continual_foragax-0.
|
132
|
-
continual_foragax-0.
|
133
|
-
continual_foragax-0.
|
131
|
+
continual_foragax-0.10.0.dist-info/METADATA,sha256=Oo3pJnjoU7VLruPXrhNT-WfzF7cGvkGN8crW-BclMsk,4897
|
132
|
+
continual_foragax-0.10.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
133
|
+
continual_foragax-0.10.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
|
134
|
+
continual_foragax-0.10.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
|
135
|
+
continual_foragax-0.10.0.dist-info/RECORD,,
|
foragax/colors.py
ADDED
@@ -0,0 +1,74 @@
|
|
1
|
+
"""Color utility functions for Foragax."""
|
2
|
+
|
3
|
+
import jax
|
4
|
+
import jax.numpy as jnp
|
5
|
+
|
6
|
+
|
7
|
+
def hsv_to_rgb(h: jax.Array, s: float = 1.0, v: float = 1.0) -> jax.Array:
|
8
|
+
"""Convert HSV color values to RGB.
|
9
|
+
|
10
|
+
Args:
|
11
|
+
h: Hue values in range [0, 1]
|
12
|
+
s: Saturation value in range [0, 1], default 1.0
|
13
|
+
v: Value (brightness) in range [0, 1], default 1.0
|
14
|
+
|
15
|
+
Returns:
|
16
|
+
RGB values as array of shape (..., 3) with values in range [0, 1]
|
17
|
+
"""
|
18
|
+
c = v * s
|
19
|
+
x = c * (1 - jnp.abs(jnp.mod(h * 6, 2) - 1))
|
20
|
+
m = v - c
|
21
|
+
|
22
|
+
# Create RGB arrays
|
23
|
+
r = jnp.zeros_like(h)
|
24
|
+
g = jnp.zeros_like(h)
|
25
|
+
b = jnp.zeros_like(h)
|
26
|
+
|
27
|
+
# Sector 0: 0-60 degrees (red to yellow)
|
28
|
+
mask0 = h < 1 / 6
|
29
|
+
r = jnp.where(mask0, c, r)
|
30
|
+
g = jnp.where(mask0, x, g)
|
31
|
+
|
32
|
+
# Sector 1: 60-120 degrees (yellow to green)
|
33
|
+
mask1 = (h >= 1 / 6) & (h < 2 / 6)
|
34
|
+
r = jnp.where(mask1, x, r)
|
35
|
+
g = jnp.where(mask1, c, g)
|
36
|
+
|
37
|
+
# Sector 2: 120-180 degrees (green to cyan)
|
38
|
+
mask2 = (h >= 2 / 6) & (h < 3 / 6)
|
39
|
+
g = jnp.where(mask2, c, g)
|
40
|
+
b = jnp.where(mask2, x, b)
|
41
|
+
|
42
|
+
# Sector 3: 180-240 degrees (cyan to blue)
|
43
|
+
mask3 = (h >= 3 / 6) & (h < 4 / 6)
|
44
|
+
g = jnp.where(mask3, x, g)
|
45
|
+
b = jnp.where(mask3, c, b)
|
46
|
+
|
47
|
+
# Sector 4: 240-300 degrees (blue to magenta)
|
48
|
+
mask4 = (h >= 4 / 6) & (h < 5 / 6)
|
49
|
+
r = jnp.where(mask4, x, r)
|
50
|
+
b = jnp.where(mask4, c, b)
|
51
|
+
|
52
|
+
# Sector 5: 300-360 degrees (magenta to red)
|
53
|
+
mask5 = h >= 5 / 6
|
54
|
+
r = jnp.where(mask5, c, r)
|
55
|
+
b = jnp.where(mask5, x, b)
|
56
|
+
|
57
|
+
# Add value offset
|
58
|
+
rgb = jnp.stack([r + m, g + m, b + m], axis=-1)
|
59
|
+
return rgb
|
60
|
+
|
61
|
+
|
62
|
+
def hsv_to_rgb_255(h: jax.Array, s: float = 0.9, v: float = 0.8) -> jax.Array:
|
63
|
+
"""Convert HSV to RGB with values scaled to 0-255 range for image rendering.
|
64
|
+
|
65
|
+
Args:
|
66
|
+
h: Hue values in range [0, 1]
|
67
|
+
s: Saturation value, default 0.9
|
68
|
+
v: Value (brightness), default 0.8
|
69
|
+
|
70
|
+
Returns:
|
71
|
+
RGB values as uint8 array of shape (..., 3) with values in range [0, 255]
|
72
|
+
"""
|
73
|
+
rgb = hsv_to_rgb(h, s, v)
|
74
|
+
return (rgb * 255).astype(jnp.uint8)
|
foragax/env.py
CHANGED
@@ -14,6 +14,7 @@ from flax import struct
|
|
14
14
|
from gymnax.environments import environment, spaces
|
15
15
|
|
16
16
|
from foragax.objects import AGENT, EMPTY, BaseForagaxObject, WeatherObject
|
17
|
+
from foragax.rendering import apply_true_borders
|
17
18
|
from foragax.weather import get_temperature
|
18
19
|
|
19
20
|
|
@@ -295,7 +296,11 @@ class ForagaxEnv(environment.Environment):
|
|
295
296
|
@partial(jax.jit, static_argnames=("self", "render_mode"))
|
296
297
|
def render(self, state: EnvState, params: EnvParams, render_mode: str = "world"):
|
297
298
|
"""Render the environment state."""
|
298
|
-
|
299
|
+
is_world_mode = render_mode in ("world", "world_true")
|
300
|
+
is_aperture_mode = render_mode in ("aperture", "aperture_true")
|
301
|
+
is_true_mode = render_mode in ("world_true", "aperture_true")
|
302
|
+
|
303
|
+
if is_world_mode:
|
299
304
|
# Create an RGB image from the object grid
|
300
305
|
img = jnp.zeros((self.size[1], self.size[0], 3))
|
301
306
|
# Decode grid for rendering: non-negative are objects, negative are empty
|
@@ -341,15 +346,18 @@ class ForagaxEnv(environment.Environment):
|
|
341
346
|
jax.image.ResizeMethod.NEAREST,
|
342
347
|
)
|
343
348
|
|
349
|
+
if is_true_mode:
|
350
|
+
# Apply true object borders by overlaying true colors on border pixels
|
351
|
+
img = apply_true_borders(img, render_grid, self.size)
|
352
|
+
|
353
|
+
# Add grid lines for world mode
|
344
354
|
grid_color = jnp.zeros(3, dtype=jnp.uint8)
|
345
355
|
row_indices = jnp.arange(1, self.size[1]) * 24
|
346
356
|
col_indices = jnp.arange(1, self.size[0]) * 24
|
347
357
|
img = img.at[row_indices, :].set(grid_color)
|
348
358
|
img = img.at[:, col_indices].set(grid_color)
|
349
359
|
|
350
|
-
|
351
|
-
|
352
|
-
elif render_mode == "aperture":
|
360
|
+
elif is_aperture_mode:
|
353
361
|
obs_grid = jnp.maximum(0, state.object_grid)
|
354
362
|
aperture = self._get_aperture(obs_grid, state.pos)
|
355
363
|
aperture_one_hot = jax.nn.one_hot(aperture, len(self.object_ids))
|
@@ -366,36 +374,89 @@ class ForagaxEnv(environment.Environment):
|
|
366
374
|
jax.image.ResizeMethod.NEAREST,
|
367
375
|
)
|
368
376
|
|
377
|
+
if is_true_mode:
|
378
|
+
# Apply true object borders by overlaying true colors on border pixels
|
379
|
+
img = apply_true_borders(img, aperture, self.aperture_size)
|
380
|
+
|
381
|
+
# Add grid lines for aperture mode
|
369
382
|
grid_color = jnp.zeros(3, dtype=jnp.uint8)
|
370
383
|
row_indices = jnp.arange(1, self.aperture_size[0]) * 24
|
371
384
|
col_indices = jnp.arange(1, self.aperture_size[1]) * 24
|
372
385
|
img = img.at[row_indices, :].set(grid_color)
|
373
386
|
img = img.at[:, col_indices].set(grid_color)
|
374
387
|
|
375
|
-
return img
|
376
388
|
else:
|
377
389
|
raise ValueError(f"Unknown render_mode: {render_mode}")
|
378
390
|
|
391
|
+
return img
|
392
|
+
|
379
393
|
|
380
394
|
class ForagaxObjectEnv(ForagaxEnv):
|
381
395
|
"""Foragax environment with object-based aperture observation."""
|
382
396
|
|
397
|
+
def __init__(
|
398
|
+
self,
|
399
|
+
size: Union[Tuple[int, int], int] = (10, 10),
|
400
|
+
aperture_size: Union[Tuple[int, int], int] = (5, 5),
|
401
|
+
objects: Tuple[BaseForagaxObject, ...] = (),
|
402
|
+
biomes: Tuple[Biome, ...] = (Biome(object_frequencies=()),),
|
403
|
+
):
|
404
|
+
super().__init__(size, aperture_size, objects, biomes)
|
405
|
+
|
406
|
+
# Compute unique colors and mapping for partial observability
|
407
|
+
# Exclude EMPTY (index 0) from color channels
|
408
|
+
object_colors_no_empty = self.object_colors[1:]
|
409
|
+
|
410
|
+
# Find unique colors in order of first appearance
|
411
|
+
unique_colors = []
|
412
|
+
color_indices = jnp.zeros(len(object_colors_no_empty), dtype=int)
|
413
|
+
color_map = {}
|
414
|
+
next_channel = 0
|
415
|
+
|
416
|
+
for i, color in enumerate(object_colors_no_empty):
|
417
|
+
color_tuple = tuple(color.tolist())
|
418
|
+
if color_tuple not in color_map:
|
419
|
+
color_map[color_tuple] = next_channel
|
420
|
+
unique_colors.append(color)
|
421
|
+
next_channel += 1
|
422
|
+
color_indices = color_indices.at[i].set(color_map[color_tuple])
|
423
|
+
|
424
|
+
self.unique_colors = jnp.array(unique_colors)
|
425
|
+
self.num_color_channels = len(unique_colors)
|
426
|
+
# color_indices maps from object_id-1 to color_channel_index
|
427
|
+
self.object_to_color_map = color_indices
|
428
|
+
|
383
429
|
def get_obs(self, state: EnvState, params: EnvParams, key=None) -> jax.Array:
|
384
|
-
num_obj_types = len(self.object_ids)
|
385
430
|
# Decode grid for observation
|
386
431
|
obs_grid = jnp.maximum(0, state.object_grid)
|
387
432
|
aperture = self._get_aperture(obs_grid, state.pos)
|
388
433
|
aperture = jnp.flip(aperture, axis=0)
|
389
|
-
|
390
|
-
|
434
|
+
|
435
|
+
# Handle case with no objects (only EMPTY)
|
436
|
+
if self.num_color_channels == 0:
|
437
|
+
return jnp.zeros(aperture.shape + (0,), dtype=jnp.float32)
|
438
|
+
|
439
|
+
# Map object IDs to color channel indices
|
440
|
+
# aperture contains object IDs (0 = EMPTY, 1+ = objects)
|
441
|
+
# For EMPTY (0), we want no color channel activated
|
442
|
+
# For objects (1+), map to color channel using object_to_color_map
|
443
|
+
color_channels = jnp.where(
|
444
|
+
aperture == 0,
|
445
|
+
-1, # Special value for EMPTY
|
446
|
+
jnp.take(self.object_to_color_map, aperture - 1, axis=0),
|
447
|
+
)
|
448
|
+
|
449
|
+
# Create one-hot encoding for color channels
|
450
|
+
# jax.nn.one_hot produces all zeros for -1 (EMPTY positions)
|
451
|
+
obs = jax.nn.one_hot(color_channels, self.num_color_channels, axis=-1)
|
452
|
+
|
391
453
|
return obs
|
392
454
|
|
393
455
|
def observation_space(self, params: EnvParams) -> spaces.Box:
|
394
|
-
num_obj_types = len(self.object_ids)
|
395
456
|
obs_shape = (
|
396
457
|
self.aperture_size[0],
|
397
458
|
self.aperture_size[1],
|
398
|
-
|
459
|
+
self.num_color_channels,
|
399
460
|
)
|
400
461
|
return spaces.Box(0, 1, obs_shape, float)
|
401
462
|
|
foragax/objects.py
CHANGED
@@ -181,7 +181,7 @@ AGENT = DefaultForagaxObject(name="agent", blocking=True, color=(0, 0, 255))
|
|
181
181
|
|
182
182
|
|
183
183
|
def create_weather_objects(
|
184
|
-
file_index: int = 0, repeat: int =
|
184
|
+
file_index: int = 0, repeat: int = 500, multiplier: float = 1.0
|
185
185
|
):
|
186
186
|
"""Create HOT and COLD WeatherObject instances using the specified file.
|
187
187
|
|
foragax/rendering.py
ADDED
@@ -0,0 +1,74 @@
|
|
1
|
+
"""Rendering utilities for Foragax environments."""
|
2
|
+
|
3
|
+
from typing import Tuple
|
4
|
+
|
5
|
+
import jax
|
6
|
+
import jax.numpy as jnp
|
7
|
+
|
8
|
+
from foragax.colors import hsv_to_rgb_255
|
9
|
+
|
10
|
+
|
11
|
+
def apply_true_borders(
|
12
|
+
base_img: jax.Array, true_grid: jax.Array, grid_size: Tuple[int, int]
|
13
|
+
) -> jax.Array:
|
14
|
+
"""Apply true object borders by overlaying HSV border colors on border pixels.
|
15
|
+
|
16
|
+
Args:
|
17
|
+
base_img: Base image with object colors
|
18
|
+
true_grid: Grid of object IDs for determining border colors
|
19
|
+
grid_size: (height, width) of the grid
|
20
|
+
|
21
|
+
Returns:
|
22
|
+
Image with HSV borders overlaid on border pixels
|
23
|
+
"""
|
24
|
+
# Create HSV border colors for each object type
|
25
|
+
num_objects = true_grid.max() + 1 # Assume object IDs start from 0
|
26
|
+
hues = jnp.linspace(0, 1, num_objects, endpoint=False)
|
27
|
+
|
28
|
+
# Convert HSV to RGB for border colors
|
29
|
+
border_colors = hsv_to_rgb_255(hues[true_grid])
|
30
|
+
|
31
|
+
# Resize border colors to match rendered image size
|
32
|
+
border_img = jax.image.resize(
|
33
|
+
border_colors,
|
34
|
+
(grid_size[0] * 24, grid_size[1] * 24, 3),
|
35
|
+
jax.image.ResizeMethod.NEAREST,
|
36
|
+
)
|
37
|
+
|
38
|
+
# Create border mask (2-pixel thick borders) - vectorized like grid lines
|
39
|
+
height, width = grid_size
|
40
|
+
img_height, img_width = height * 24, width * 24
|
41
|
+
|
42
|
+
border_mask = jnp.zeros((img_height, img_width), dtype=bool)
|
43
|
+
|
44
|
+
# Create border row and column indices for all cells at once
|
45
|
+
cell_rows = jnp.arange(height)
|
46
|
+
cell_cols = jnp.arange(width)
|
47
|
+
|
48
|
+
# Top border rows: 2 rows per cell
|
49
|
+
top_border_rows = cell_rows[:, None] * 24 + jnp.arange(2)[None, :]
|
50
|
+
top_border_rows_flat = top_border_rows.flatten()
|
51
|
+
|
52
|
+
# Bottom border rows: 2 rows per cell
|
53
|
+
bottom_border_rows = cell_rows[:, None] * 24 + 22 + jnp.arange(2)[None, :]
|
54
|
+
bottom_border_rows_flat = bottom_border_rows.flatten()
|
55
|
+
|
56
|
+
# Left border columns: 2 columns per cell
|
57
|
+
left_border_cols = cell_cols[:, None] * 24 + jnp.arange(2)[None, :]
|
58
|
+
left_border_cols_flat = left_border_cols.flatten()
|
59
|
+
|
60
|
+
# Right border columns: 2 columns per cell
|
61
|
+
right_border_cols = cell_cols[:, None] * 24 + 22 + jnp.arange(2)[None, :]
|
62
|
+
right_border_cols_flat = right_border_cols.flatten()
|
63
|
+
|
64
|
+
# Set top and bottom borders (full width rectangles)
|
65
|
+
all_border_rows = jnp.concatenate([top_border_rows_flat, bottom_border_rows_flat])
|
66
|
+
border_mask = border_mask.at[all_border_rows, :].set(True)
|
67
|
+
|
68
|
+
# Set left and right borders (full height rectangles)
|
69
|
+
all_border_cols = jnp.concatenate([left_border_cols_flat, right_border_cols_flat])
|
70
|
+
border_mask = border_mask.at[:, all_border_cols].set(True)
|
71
|
+
|
72
|
+
# Apply border mask: use HSV border colors for border pixels, base colors elsewhere
|
73
|
+
result_img = jnp.where(border_mask[..., None], border_img, base_img)
|
74
|
+
return result_img
|
File without changes
|
File without changes
|
File without changes
|