flaxdiff 0.2.7__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,617 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import numpy as np
4
+ import math
5
+ import einops
6
+ import matplotlib.pyplot as plt
7
+ from matplotlib.colors import LinearSegmentedColormap
8
+ from typing import Tuple
9
+
10
+ # --- Core Hilbert Curve Logic ---
11
+
12
+ def _d2xy(n: int, d: int) -> Tuple[int, int]:
13
+ """
14
+ Convert a 1D Hilbert curve index to 2D (x, y) coordinates.
15
+ Based on the algorithm from Wikipedia / common implementations.
16
+
17
+ Args:
18
+ n: Size of the grid (must be a power of 2).
19
+ d: 1D Hilbert curve index (0 to n*n-1).
20
+
21
+ Returns:
22
+ Tuple of (x, y) coordinates (column, row).
23
+ """
24
+ x = y = 0
25
+ t = d
26
+ s = 1
27
+ while (s < n):
28
+ # Extract the two bits for the current level
29
+ rx = (t >> 1) & 1
30
+ ry = (t ^ rx) & 1 # Use XOR to determine the y bit based on d's pattern
31
+
32
+ # Rotate and flip the quadrant appropriately
33
+ if ry == 0:
34
+ if rx == 1:
35
+ x = (s - 1) - x
36
+ y = (s - 1) - y
37
+ # Swap x and y
38
+ x, y = y, x
39
+
40
+ # Add the offsets for the current quadrant
41
+ x += s * rx
42
+ y += s * ry
43
+
44
+ # Move to the next level
45
+ t >>= 2 # Equivalent to t //= 4
46
+ s <<= 1 # Equivalent to s *= 2
47
+ return x, y # Returns (column, row)
48
+
49
def hilbert_indices(H_P: int, W_P: int) -> jnp.ndarray:
    """
    Build the Hilbert-order traversal of an H_P x W_P patch grid.

    The rectangle is embedded in the smallest enclosing power-of-two square.
    The square is walked along its Hilbert curve, and every visited cell
    that lies outside the rectangle is skipped.

    Args:
        H_P: Grid height in patches.
        W_P: Grid width in patches.

    Returns:
        1D int32 JAX array of length H_P * W_P where element i is the
        row-major index (row * W_P + col) of the i-th patch along the curve.
        An empty array is returned when the grid has no cells.
    """
    n_cells = H_P * W_P
    if n_cells == 0:
        return jnp.array([], dtype=jnp.int32)

    # Side of the enclosing square: smallest power of two >= the longer side.
    longest_side = max(H_P, W_P)
    order = math.ceil(math.log2(longest_side)) if longest_side > 0 else 0
    n = 1 << order

    row_major = []
    for d in range(n * n):
        col, row = _d2xy(n, d)
        # Cells of the padded square that fall outside the real grid are dropped.
        if col >= W_P or row >= H_P:
            continue
        row_major.append(row * W_P + col)
        # Every real cell has been seen; the rest of the square is padding.
        if len(row_major) == n_cells:
            break

    return jnp.array(row_major, dtype=jnp.int32)
93
+
94
def inverse_permutation(idx: jnp.ndarray, total_size: int) -> jnp.ndarray:
    """
    Invert an index mapping, marking unmapped targets with -1.

    Args:
        idx: 1D array of length N where idx[h] = k maps source position h
            (e.g. Hilbert sequence index) to target position k (e.g.
            row-major index). Values are expected to be unique.
        total_size: Number of possible target positions (e.g. H_P * W_P).

    Returns:
        int32 array `inv` of length `total_size` with inv[k] = h whenever
        idx[h] = k, and inv[k] = -1 for every k that does not occur in `idx`.
    """
    # Source positions 0 .. N-1, scattered into a table pre-filled with the
    # "not mapped" sentinel.
    positions = jnp.arange(idx.shape[0], dtype=jnp.int32)
    unmapped = jnp.full((total_size,), -1, dtype=jnp.int32)
    return unmapped.at[idx].set(positions)
121
+
122
+ # --- Patching Logic ---
123
+
124
def patchify(x: jnp.ndarray, patch_size: int) -> jnp.ndarray:
    """
    Split an image batch into flattened square patches, in row-major order.

    Args:
        x: Image tensor of shape [B, H, W, C].
        patch_size: Side length P of each square patch.

    Returns:
        Tensor of shape [B, N, P*P*C] with N = (H/P) * (W/P). Patches are
        ordered left-to-right, then top-to-bottom.

    Raises:
        ValueError: If H or W is not a multiple of `patch_size`.
    """
    _, H, W, _ = x.shape
    if H % patch_size or W % patch_size:
        raise ValueError(f"Image dimensions ({H}, {W}) must be divisible by patch_size ({patch_size})")

    # (h w) is the patch-sequence axis; (p1 p2 c) flattens one patch.
    pattern = 'b (h p1) (w p2) c -> b (h w) (p1 p2 c)'
    return einops.rearrange(x, pattern, p1=patch_size, p2=patch_size)
145
+
146
def unpatchify(x: jnp.ndarray, patch_size: int, H: int, W: int, C: int) -> jnp.ndarray:
    """
    Reassemble an image batch from row-major flattened patches.

    Args:
        x: Patch tensor of shape [B, N, P*P*C] with N = (H/P) * (W/P),
            ordered row-major over the patch grid.
        patch_size: Side length P of each square patch.
        H: Height of the image to rebuild.
        W: Width of the image to rebuild.
        C: Number of channels.

    Returns:
        Image tensor of shape [B, (H // P) * P, (W // P) * P, C]. This equals
        [B, H, W, C] when H and W are multiples of P.
    """
    H_P, W_P = H // patch_size, W // patch_size
    expected_patches = H_P * W_P
    actual_patches = x.shape[1]

    # The sequence axis must factor exactly into the H_P x W_P patch grid.
    assert actual_patches == expected_patches, \
        f"Number of patches ({actual_patches}) does not match expected ({expected_patches}) for H={H}, W={W}, patch_size={patch_size}"

    pattern = 'b (h w) (p1 p2 c) -> b (h p1) (w p2) c'
    return einops.rearrange(x, pattern, h=H_P, w=W_P, p1=patch_size, p2=patch_size, c=C)
174
+
175
def hilbert_patchify(x: jnp.ndarray, patch_size: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Patchify an image batch and order the patches along a Hilbert curve.

    Args:
        x: Image tensor of shape [B, H, W, C].
        patch_size: Side length P of each square patch.

    Returns:
        Tuple of:
            - patches_hilbert: Patches of shape [B, N, P*P*C] (N = H_P * W_P),
              where position h holds the h-th patch along the curve.
            - inv_idx: int32 array of shape [N] mapping each row-major patch
              index to its position in the Hilbert sequence (or -1 if absent).
    """
    _, H, W, _ = x.shape
    H_P, W_P = H // patch_size, W // patch_size

    # Row-major patches first; this also validates divisibility by patch_size.
    patches_row_major = patchify(x, patch_size)

    # forward[h] = k: the h-th curve position reads row-major patch k.
    forward = hilbert_indices(H_P, W_P)
    # inv_idx[k] = h: kept so the reordering can be undone later.
    inv_idx = inverse_permutation(forward, H_P * W_P)

    # Gather along the sequence axis in curve order.
    patches_hilbert = patches_row_major[:, forward, :]
    return patches_hilbert, inv_idx
209
+
210
def hilbert_unpatchify(x: jnp.ndarray, inv_idx: jnp.ndarray, patch_size: int, H: int, W: int, C: int) -> jnp.ndarray:
    """
    Undo the Hilbert reordering of patches and rebuild the image.

    The scatter back to row-major order is expressed as a gather followed by
    a masked select, so all shapes are static and the function can be traced
    by JIT.

    Args:
        x: Hilbert-ordered patches of shape [B, N, P*P*C].
        inv_idx: Array of length (H // P) * (W // P) mapping each row-major
            patch index k to its Hilbert sequence index h, or -1 if no patch
            exists for k.
        patch_size: Side length P of each square patch.
        H: Original image height.
        W: Original image width.
        C: Number of channels.

    Returns:
        Image tensor of shape [B, H, W, C]. Patches whose inverse index is
        -1 or >= N are filled with zeros.
    """
    N = x.shape[1]  # number of patches actually supplied (h axis)
    total_patches_expected = (H // patch_size) * (W // patch_size)  # k axis

    assert inv_idx.shape[0] == total_patches_expected, \
        f"Inverse index size {inv_idx.shape[0]} does not match expected total patches {total_patches_expected}"

    # Clamp into [0, N-1] so the gather below never reads out of bounds.
    # Positions that needed clamping are zeroed out by the mask afterwards.
    safe_h = jnp.minimum(jnp.maximum(inv_idx, 0), N - 1)

    def _gather_single(patches, h_for_k):
        # patches: [N, D], h_for_k: [K]  ->  [K, D] with out[k] = patches[h_for_k[k]]
        return patches[h_for_k, :]

    # Same index vector for every batch element: [B, K, D].
    gathered = jax.vmap(_gather_single, in_axes=(0, None))(x, safe_h)

    # k is valid only if its h refers to a patch that is present in x.
    valid_k = (inv_idx >= 0) & (inv_idx < N)
    row_major_patches = jnp.where(
        valid_k[None, :, None],  # broadcast over batch and patch dims
        gathered,
        jnp.zeros((x.shape[0], total_patches_expected, x.shape[2]), dtype=x.dtype),
    )

    return unpatchify(row_major_patches, patch_size, H, W, C)
280
+ # --- Visualization and Demo ---
281
+
282
def visualize_hilbert_curve(H: int, W: int, patch_size: int, figsize=(12, 5)):
    """
    Visualize the Hilbert curve mapping for a given image patch grid size.

    Two panels are drawn side by side: the row-major numbering of the patch
    grid, and the same grid coloured and labelled by Hilbert sequence index
    with the traversal path drawn through the cell centres.

    Args:
        H: Image height.
        W: Image width.
        patch_size: Size of each patch.
        figsize: Figure size for the plot.

    Returns:
        The matplotlib Figure object, or None (after printing a warning)
        when the patch grid has no cells.
    """
    H_P = H // patch_size
    W_P = W // patch_size
    if H_P * W_P == 0:
        print("Warning: Grid dimensions are zero, cannot visualize.")
        return None

    # idx[i] = row-major index of the i-th point along the curve.
    idx = np.array(hilbert_indices(H_P, W_P))  # numpy copy for plotting
    n_points = len(idx)

    # (row, col) of every curve point, already in Hilbert order.
    path_rows, path_cols = np.divmod(idx, W_P)

    # grid[row, col] = Hilbert sequence index; -1 marks cells never visited.
    grid = np.full((H_P, W_P), -1.0)
    grid[path_rows, path_cols] = np.arange(n_points)

    # Blue -> Green -> Yellow -> Red along the path.
    cmap = LinearSegmentedColormap.from_list('hilbert', ['#0000FF', '#00FF00', '#FFFF00', '#FF0000'])

    fig, axes = plt.subplots(1, 2, figsize=figsize)

    # --- Plot 1: Original Grid (Row-Major Order) ---
    orig_grid = np.arange(H_P * W_P).reshape((H_P, W_P))
    im0 = axes[0].imshow(orig_grid, cmap='viridis', aspect='auto')
    axes[0].set_title(f"Original Grid ({H_P}x{W_P})\n(Row-Major Order)")
    for r in range(H_P):
        for c in range(W_P):
            # White text on the dark half of viridis, black on the bright half.
            axes[0].text(c, r, f'{orig_grid[r, c]}', ha='center', va='center', color='white' if orig_grid[r, c] < (H_P * W_P) / 2 else 'black', fontsize=8)
    axes[0].set_xticks(np.arange(W_P))
    axes[0].set_yticks(np.arange(H_P))
    axes[0].set_xticklabels(np.arange(W_P))
    axes[0].set_yticklabels(np.arange(H_P))
    plt.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04, label="Row-Major Index")

    # --- Plot 2: Hilbert Curve Ordering ---
    masked_grid = np.ma.masked_where(grid == -1, grid)  # hide unmapped cells
    im1 = axes[1].imshow(masked_grid, cmap=cmap, aspect='auto', vmin=0, vmax=max(0, len(idx)-1))
    axes[1].set_title(f"Hilbert Curve Ordering ({len(idx)} points)")
    for r in range(H_P):
        for c in range(W_P):
            if grid[r, c] != -1:
                axes[1].text(c, r, f'{int(grid[r, c])}', ha='center', va='center', color='black', fontsize=8)
    axes[1].set_xticks(np.arange(W_P))
    axes[1].set_yticks(np.arange(H_P))
    axes[1].set_xticklabels(np.arange(W_P))
    axes[1].set_yticklabels(np.arange(H_P))
    plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04, label="Hilbert Sequence Index")

    # Draw the path through the patch centres in Hilbert order.
    if n_points > 1:
        # imshow centres cell (row, col) at data coordinates (x=col, y=row),
        # the same anchor used for the text labels above. Adding 0.5 would
        # put the path on the cell corners instead of the centres.
        x_coords = path_cols
        y_coords = path_rows
        axes[1].plot(x_coords, y_coords, color='black', linestyle='-', linewidth=1.5, alpha=0.8)
        # Green circle marks the start, magenta circle marks the end.
        axes[1].plot(x_coords[0], y_coords[0], 'go', markersize=8, label='Start (Idx 0)')
        axes[1].plot(x_coords[-1], y_coords[-1], 'mo', markersize=8, label=f'End (Idx {len(idx)-1})')
        axes[1].legend(fontsize='small')

    plt.tight_layout()
    return fig
377
+
378
def create_patch_grid(patches_np: np.ndarray, patch_size: int, channels: int, grid_cols: int = 10, border: int = 1):
    """
    Tile a sequence of flattened patches into a single contact-sheet image.

    Patches are laid out left-to-right, top-to-bottom, separated by a
    background-coloured border.

    Args:
        patches_np: NumPy array of shape [N, P*P*C], one flattened patch per row.
        patch_size: Side length P of each square patch.
        channels: Number of channels C.
        grid_cols: Maximum number of patches per row of the sheet.
        border: Width in pixels of the gap between neighbouring patches.

    Returns:
        The sheet as a NumPy array: 2D for a single channel, otherwise
        [height, width, C]. Float data is clipped to [0, 1]; integer data is
        clipped to [0, 255] and returned as uint8. If there are no patches,
        or the input cannot be reshaped to [N, P, P, C], a zero placeholder
        of shape [P, P, C] is returned (a message is printed in the latter case).
    """
    count = patches_np.shape[0]
    placeholder_shape = (patch_size, patch_size, channels)
    if count == 0:
        return np.zeros(placeholder_shape, dtype=patches_np.dtype)

    # Unflatten every row into a P x P x C tile.
    try:
        tiles = patches_np.reshape(count, patch_size, patch_size, channels)
    except ValueError as e:
        print(f"Error reshaping patches: {e}")
        print(f"Input shape: {patches_np.shape}, Expected P*P*C: {patch_size*patch_size*channels}")
        return np.zeros(placeholder_shape, dtype=patches_np.dtype)

    single_channel = channels == 1

    # Sheet layout: never wider than the number of tiles available.
    grid_cols = min(grid_cols, count)
    grid_rows = int(np.ceil(count / grid_cols))

    # Each tile occupies `stride` pixels, except that the last row/column
    # has no trailing border.
    stride = patch_size + border
    canvas_shape = (grid_rows * stride - border, grid_cols * stride - border)
    if not single_channel:
        canvas_shape = canvas_shape + (channels,)

    # Background value 255 (ends up as 1.0 for float data after clipping).
    grid = np.ones(canvas_shape, dtype=tiles.dtype) * 255

    for position, tile in enumerate(tiles):
        sheet_row, sheet_col = divmod(position, grid_cols)
        top = sheet_row * stride
        left = sheet_col * stride
        grid[top:top + patch_size, left:left + patch_size] = tile[:, :, 0] if single_channel else tile

    # Bring values into a displayable range for the dtype family.
    if np.issubdtype(grid.dtype, np.floating):
        grid = np.clip(grid, 0, 1)
    elif np.issubdtype(grid.dtype, np.integer):
        grid = np.clip(grid, 0, 255).astype(np.uint8)  # uint8 for imshow

    return grid.squeeze() if single_channel else grid
447
+
448
+
449
def demo_hilbert_patching(image_np: np.ndarray, patch_size: int = 8, figsize=(15, 12)):
    """
    Demonstrate the Hilbert curve patching process on an image.

    Builds two figures. The first shows the (cropped) image, the image with
    the Hilbert traversal overlaid, and the leading patches in row-major
    versus Hilbert order. The second compares the image with its
    reconstruction from Hilbert-ordered patches. Shapes and the mean
    absolute reconstruction error are printed along the way.

    Args:
        image_np: NumPy array of shape [H, W, C] or [H, W].
        patch_size: Size of square patches.
        figsize: Figure size for the main plot.

    Returns:
        Tuple of (fig_main, fig_reconstruction) matplotlib Figure objects.
    """
    # Grayscale input gets an explicit channel axis.
    if image_np.ndim == 2:
        image_np = np.expand_dims(image_np, axis=-1)

    # Crop so that both sides are whole multiples of patch_size.
    H_orig, W_orig, C = image_np.shape
    H = (H_orig // patch_size) * patch_size
    W = (W_orig // patch_size) * patch_size
    if H != H_orig or W != W_orig:
        print(f"Warning: Cropping image from ({H_orig}, {W_orig}) to ({H}, {W}) to be divisible by patch_size={patch_size}")
        image_np = image_np[:H, :W, :]

    # JAX array with a leading batch axis: [1, H, W, C].
    image = jnp.expand_dims(jnp.array(image_np), axis=0)
    _, H, W, C = image.shape
    H_P = H // patch_size
    W_P = W // patch_size

    print(f"Image shape: {image.shape}, Patch size: {patch_size}, Grid: {H_P}x{W_P}")

    gray_cmap = 'gray' if C == 1 else None

    # --- Create Main Visualization Figure ---
    fig, axes = plt.subplots(2, 2, figsize=figsize)

    # 1. Original image (cropped)
    display_img = np.array(image[0])  # back to numpy for display
    axes[0, 0].imshow(display_img.squeeze(), cmap=gray_cmap)
    axes[0, 0].set_title(f"Original Image ({H}x{W})")
    axes[0, 0].axis('off')

    # 2. Original image with Hilbert curve overlay
    axes[0, 1].imshow(display_img.squeeze(), cmap=gray_cmap)
    axes[0, 1].set_title("Image with Hilbert Curve Overlay")

    # idx[i] is the row-major index of the i-th patch along the curve, so the
    # path order is simply idx decomposed into (row, col).
    idx = np.array(hilbert_indices(H_P, W_P))
    if len(idx) > 1:
        patch_rows, patch_cols = np.divmod(idx, W_P)
        # imshow centres pixel j at coordinate j. A patch covering pixels
        # r*P .. r*P + P - 1 is therefore centred at r*P + (P - 1) / 2.
        center_offset = (patch_size - 1) / 2
        y_coords = patch_rows * patch_size + center_offset
        x_coords = patch_cols * patch_size + center_offset
        axes[0, 1].plot(x_coords, y_coords, 'r-', linewidth=1.5, alpha=0.7)
        axes[0, 1].plot(x_coords[0], y_coords[0], 'go', markersize=5)  # start
        axes[0, 1].plot(x_coords[-1], y_coords[-1], 'mo', markersize=5)  # end
    axes[0, 1].axis('off')

    # 3. Apply Hilbert Patchify
    patches_hilbert, inv_idx = hilbert_patchify(image, patch_size)
    print(f"Hilbert patches shape: {patches_hilbert.shape}")  # [B, N, P*P*C]
    print(f"Inverse index shape: {inv_idx.shape}")  # [total_patches_expected]

    # Row-major patches for comparison.
    patches_row_major = patchify(image, patch_size)
    print(f"Row-major patches shape: {patches_row_major.shape}")  # [B, N, P*P*C]

    # Only the leading patches of each ordering are shown.
    n_display = min(60, patches_hilbert.shape[1])
    patches_hilbert_np = np.array(patches_hilbert[0, :n_display])
    patches_row_major_np = np.array(patches_row_major[0, :n_display])

    patch_grid_row = create_patch_grid(patches_row_major_np, patch_size, C, grid_cols=10)
    patch_grid_hil = create_patch_grid(patches_hilbert_np, patch_size, C, grid_cols=10)

    axes[1, 0].imshow(patch_grid_row, cmap=gray_cmap, aspect='auto')
    axes[1, 0].set_title(f"First {n_display} Patches (Row-Major Order)")
    axes[1, 0].axis('off')

    axes[1, 1].imshow(patch_grid_hil, cmap=gray_cmap, aspect='auto')
    axes[1, 1].set_title(f"First {n_display} Patches (Hilbert Order)")
    axes[1, 1].axis('off')

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave headroom for the suptitle
    fig.suptitle(f"Hilbert Patching Demo (Patch Size: {patch_size}x{patch_size})", fontsize=16)

    # --- Create Reconstruction Figure ---
    fig2, axes2 = plt.subplots(1, 2, figsize=(12, 6))

    # 4. Unpatchify and verify
    reconstructed = hilbert_unpatchify(patches_hilbert, inv_idx, patch_size, H, W, C)
    print(f"Reconstructed image shape: {reconstructed.shape}")

    error = jnp.mean(jnp.abs(image - reconstructed))
    print(f"Reconstruction Mean Absolute Error: {error:.6f}")

    reconstructed_np = np.array(reconstructed[0])  # back to numpy
    axes2[0].imshow(display_img.squeeze(), cmap=gray_cmap)
    axes2[0].set_title("Original Image (Cropped)")
    axes2[0].axis('off')

    axes2[1].imshow(reconstructed_np.squeeze(), cmap=gray_cmap)
    axes2[1].set_title(f"Reconstructed from Hilbert Patches\nMAE: {error:.4f}")
    axes2[1].axis('off')

    plt.tight_layout()
    fig2.suptitle("Image Reconstruction Verification", fontsize=16)

    return fig, fig2
577
+
578
+
579
+ # --- Example Usage ---
580
if __name__ == '__main__':
    # Synthetic rectangular RGB gradient used as the demo image.
    # (Use H, W, C = 64, 64, 1 for a square grayscale image instead.)
    H, W, C = 64, 80, 3
    img_np = np.zeros((H, W, C), dtype=np.float32)
    xv, yv = np.meshgrid(np.linspace(0, 1, W), np.linspace(0, 1, H))
    diagonal = (xv + yv) / 2

    if C == 3:
        img_np[..., 0] = xv        # red varies with width
        img_np[..., 1] = yv        # green varies with height
        img_np[..., 2] = diagonal  # blue is the average of both
    else:
        img_np[..., 0] = diagonal  # grayscale

    # --- Curve visualization on a 4x5 patch grid (64x80 pixels) ---
    patch_size_vis = 16
    H_vis = 4 * patch_size_vis
    W_vis = 5 * patch_size_vis
    print(f"\nVisualizing Hilbert curve for {H_vis//patch_size_vis}x{W_vis//patch_size_vis} patch grid...")
    fig_vis = visualize_hilbert_curve(H_vis, W_vis, patch_size_vis)
    if fig_vis:
        # To keep a copy: fig_vis.savefig("hilbert_curve_visualization.png")
        plt.show()

    # --- Full patching demo on the gradient image ---
    patch_size_demo = 8
    print(f"\nRunning Hilbert patching demo with patch size {patch_size_demo}...")
    fig_main, fig_recon = demo_hilbert_patching(img_np, patch_size=patch_size_demo)
    # To keep copies: fig_main.savefig("hilbert_patching_demo.png")
    #                 fig_recon.savefig("hilbert_reconstruction.png")
    plt.show()

    # --- Edge case: small non-power-of-two grid (12x20 pixels, 4px patches) ---
    print("\nTesting small image (3x5 patches)...")
    img_small_np = img_np[:3 * 4, :5 * 4, :]
    fig_main_small, fig_recon_small = demo_hilbert_patching(img_small_np, patch_size=4)
    plt.show()