continual-foragax 0.40.0__py3-none-any.whl → 0.41.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foragax/objects.py CHANGED
@@ -327,15 +327,18 @@ class FourierObject(BaseForagaxObject):
327
327
 
328
328
  # Extract interleaved coefficients: [a1, b1, a2, b2, ...]
329
329
  ab_coeffs = params[3:]
330
+ # Reuse existing arrays via reshaping/slices for vectorization
331
+ # Reshape ab_coeffs to (num_terms, 2) where col 0 is a_n, col 1 is b_n
330
332
  n_terms = len(ab_coeffs) // 2
333
+ coeffs = ab_coeffs.reshape((n_terms, 2))
334
+ a_coeffs = coeffs[:, 0]
335
+ b_coeffs = coeffs[:, 1]
331
336
 
332
- # Compute Fourier series: sum(a_n*cos(n*t) + b_n*sin(n*t))
333
- reward = 0.0
334
- for i in range(n_terms):
335
- freq = i + 1
336
- a_i = ab_coeffs[2 * i] # a coefficient at index 2i
337
- b_i = ab_coeffs[2 * i + 1] # b coefficient at index 2i+1
338
- reward += a_i * jnp.cos(freq * t) + b_i * jnp.sin(freq * t)
337
+ freqs = jnp.arange(1, n_terms + 1, dtype=jnp.float32)
338
+ terms = freqs * t
339
+
340
+ # Calculate sum(a_n * cos(n*t) + b_n * sin(n*t))
341
+ reward = jnp.sum(a_coeffs * jnp.cos(terms) + b_coeffs * jnp.sin(terms))
339
342
 
340
343
  # Apply min-max normalization to [-1, 1], then scale by base_magnitude
341
344
  # Formula: 2 * (x - min) / (max - min) - 1
foragax/rendering.py CHANGED
@@ -38,39 +38,15 @@ def apply_true_borders(
38
38
  jax.image.ResizeMethod.NEAREST,
39
39
  )
40
40
 
41
- # Create border mask (2-pixel thick borders) - vectorized like grid lines
42
- height, width = grid_size
43
- img_height, img_width = height * 24, width * 24
44
-
45
- border_mask = jnp.zeros((img_height, img_width), dtype=bool)
46
-
47
- # Create border row and column indices for all cells at once
48
- cell_rows = jnp.arange(height)
49
- cell_cols = jnp.arange(width)
50
-
51
- # Top border rows: 2 rows per cell
52
- top_border_rows = cell_rows[:, None] * 24 + jnp.arange(2)[None, :]
53
- top_border_rows_flat = top_border_rows.flatten()
54
-
55
- # Bottom border rows: 2 rows per cell
56
- bottom_border_rows = cell_rows[:, None] * 24 + 22 + jnp.arange(2)[None, :]
57
- bottom_border_rows_flat = bottom_border_rows.flatten()
58
-
59
- # Left border columns: 2 columns per cell
60
- left_border_cols = cell_cols[:, None] * 24 + jnp.arange(2)[None, :]
61
- left_border_cols_flat = left_border_cols.flatten()
62
-
63
- # Right border columns: 2 columns per cell
64
- right_border_cols = cell_cols[:, None] * 24 + 22 + jnp.arange(2)[None, :]
65
- right_border_cols_flat = right_border_cols.flatten()
66
-
67
- # Set top and bottom borders (full width rectangles)
68
- all_border_rows = jnp.concatenate([top_border_rows_flat, bottom_border_rows_flat])
69
- border_mask = border_mask.at[all_border_rows, :].set(True)
70
-
71
- # Set left and right borders (full height rectangles)
72
- all_border_cols = jnp.concatenate([left_border_cols_flat, right_border_cols_flat])
73
- border_mask = border_mask.at[:, all_border_cols].set(True)
41
+ # Create border mask (2-pixel thick borders) using vectorized modulo operations
42
+ img_height, img_width = grid_size[0] * 24, grid_size[1] * 24
43
+ y_idx = jnp.arange(img_height) % 24
44
+ x_idx = jnp.arange(img_width) % 24
45
+
46
+ # Border pixels are those with offset 0, 1, 22, or 23 within each 24x24 cell
47
+ is_border_row = (y_idx < 2) | (y_idx >= 22)
48
+ is_border_col = (x_idx < 2) | (x_idx >= 22)
49
+ border_mask = is_border_row[:, None] | is_border_col[None, :]
74
50
 
75
51
  # Apply border mask: use HSV border colors for border pixels, base colors elsewhere
76
52
  result_img = jnp.where(border_mask[..., None], border_img, base_img)