continual-foragax 0.40.0__py3-none-any.whl → 0.41.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.40.0.dist-info → continual_foragax-0.41.0.dist-info}/METADATA +1 -1
- {continual_foragax-0.40.0.dist-info → continual_foragax-0.41.0.dist-info}/RECORD +8 -8
- foragax/env.py +514 -389
- foragax/objects.py +10 -7
- foragax/rendering.py +9 -33
- {continual_foragax-0.40.0.dist-info → continual_foragax-0.41.0.dist-info}/WHEEL +0 -0
- {continual_foragax-0.40.0.dist-info → continual_foragax-0.41.0.dist-info}/entry_points.txt +0 -0
- {continual_foragax-0.40.0.dist-info → continual_foragax-0.41.0.dist-info}/top_level.txt +0 -0
foragax/objects.py
CHANGED
|
@@ -327,15 +327,18 @@ class FourierObject(BaseForagaxObject):
|
|
|
327
327
|
|
|
328
328
|
# Extract interleaved coefficients: [a1, b1, a2, b2, ...]
|
|
329
329
|
ab_coeffs = params[3:]
|
|
330
|
+
# Reuse existing arrays via reshaping/slices for vectorization
|
|
331
|
+
# Reshape ab_coeffs to (num_terms, 2) where col 0 is a_n, col 1 is b_n
|
|
330
332
|
n_terms = len(ab_coeffs) // 2
|
|
333
|
+
coeffs = ab_coeffs.reshape((n_terms, 2))
|
|
334
|
+
a_coeffs = coeffs[:, 0]
|
|
335
|
+
b_coeffs = coeffs[:, 1]
|
|
331
336
|
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
b_i = ab_coeffs[2 * i + 1] # b coefficient at index 2i+1
|
|
338
|
-
reward += a_i * jnp.cos(freq * t) + b_i * jnp.sin(freq * t)
|
|
337
|
+
freqs = jnp.arange(1, n_terms + 1, dtype=jnp.float32)
|
|
338
|
+
terms = freqs * t
|
|
339
|
+
|
|
340
|
+
# Calculate sum(a_n * cos(n*t) + b_n * sin(n*t))
|
|
341
|
+
reward = jnp.sum(a_coeffs * jnp.cos(terms) + b_coeffs * jnp.sin(terms))
|
|
339
342
|
|
|
340
343
|
# Apply min-max normalization to [-1, 1], then scale by base_magnitude
|
|
341
344
|
# Formula: 2 * (x - min) / (max - min) - 1
|
foragax/rendering.py
CHANGED
|
@@ -38,39 +38,15 @@ def apply_true_borders(
|
|
|
38
38
|
jax.image.ResizeMethod.NEAREST,
|
|
39
39
|
)
|
|
40
40
|
|
|
41
|
-
# Create border mask (2-pixel thick borders)
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
# Top border rows: 2 rows per cell
|
|
52
|
-
top_border_rows = cell_rows[:, None] * 24 + jnp.arange(2)[None, :]
|
|
53
|
-
top_border_rows_flat = top_border_rows.flatten()
|
|
54
|
-
|
|
55
|
-
# Bottom border rows: 2 rows per cell
|
|
56
|
-
bottom_border_rows = cell_rows[:, None] * 24 + 22 + jnp.arange(2)[None, :]
|
|
57
|
-
bottom_border_rows_flat = bottom_border_rows.flatten()
|
|
58
|
-
|
|
59
|
-
# Left border columns: 2 columns per cell
|
|
60
|
-
left_border_cols = cell_cols[:, None] * 24 + jnp.arange(2)[None, :]
|
|
61
|
-
left_border_cols_flat = left_border_cols.flatten()
|
|
62
|
-
|
|
63
|
-
# Right border columns: 2 columns per cell
|
|
64
|
-
right_border_cols = cell_cols[:, None] * 24 + 22 + jnp.arange(2)[None, :]
|
|
65
|
-
right_border_cols_flat = right_border_cols.flatten()
|
|
66
|
-
|
|
67
|
-
# Set top and bottom borders (full width rectangles)
|
|
68
|
-
all_border_rows = jnp.concatenate([top_border_rows_flat, bottom_border_rows_flat])
|
|
69
|
-
border_mask = border_mask.at[all_border_rows, :].set(True)
|
|
70
|
-
|
|
71
|
-
# Set left and right borders (full height rectangles)
|
|
72
|
-
all_border_cols = jnp.concatenate([left_border_cols_flat, right_border_cols_flat])
|
|
73
|
-
border_mask = border_mask.at[:, all_border_cols].set(True)
|
|
41
|
+
# Create border mask (2-pixel thick borders) using vectorized modulo operations
|
|
42
|
+
img_height, img_width = grid_size[0] * 24, grid_size[1] * 24
|
|
43
|
+
y_idx = jnp.arange(img_height) % 24
|
|
44
|
+
x_idx = jnp.arange(img_width) % 24
|
|
45
|
+
|
|
46
|
+
# Border pixels are those with offset 0, 1, 22, or 23 within each 24x24 cell
|
|
47
|
+
is_border_row = (y_idx < 2) | (y_idx >= 22)
|
|
48
|
+
is_border_col = (x_idx < 2) | (x_idx >= 22)
|
|
49
|
+
border_mask = is_border_row[:, None] | is_border_col[None, :]
|
|
74
50
|
|
|
75
51
|
# Apply border mask: use HSV border colors for border pixels, base colors elsewhere
|
|
76
52
|
result_img = jnp.where(border_mask[..., None], border_img, base_img)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|