flaxdiff 0.2.6.1__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/dataloaders.py +36 -24
- flaxdiff/data/dataset_map.py +2 -2
- flaxdiff/data/sources/base.py +12 -0
- flaxdiff/data/sources/images.py +71 -12
- flaxdiff/data/sources/videos.py +5 -0
- flaxdiff/inference/pipeline.py +9 -4
- flaxdiff/inference/utils.py +2 -2
- flaxdiff/models/common.py +1 -70
- flaxdiff/models/hilbert.py +617 -0
- flaxdiff/models/simple_dit.py +476 -0
- flaxdiff/models/simple_mmdit.py +861 -0
- flaxdiff/models/simple_vit.py +278 -117
- flaxdiff/trainer/general_diffusion_trainer.py +29 -10
- flaxdiff/trainer/simple_trainer.py +113 -19
- {flaxdiff-0.2.6.1.dist-info → flaxdiff-0.2.8.dist-info}/METADATA +1 -1
- {flaxdiff-0.2.6.1.dist-info → flaxdiff-0.2.8.dist-info}/RECORD +18 -16
- {flaxdiff-0.2.6.1.dist-info → flaxdiff-0.2.8.dist-info}/WHEEL +1 -1
- flaxdiff/models/better_uvit.py +0 -380
- {flaxdiff-0.2.6.1.dist-info → flaxdiff-0.2.8.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,617 @@
|
|
1
|
+
import jax
|
2
|
+
import jax.numpy as jnp
|
3
|
+
import numpy as np
|
4
|
+
import math
|
5
|
+
import einops
|
6
|
+
import matplotlib.pyplot as plt
|
7
|
+
from matplotlib.colors import LinearSegmentedColormap
|
8
|
+
from typing import Tuple
|
9
|
+
|
10
|
+
# --- Core Hilbert Curve Logic ---
|
11
|
+
|
12
|
+
def _d2xy(n: int, d: int) -> Tuple[int, int]:
|
13
|
+
"""
|
14
|
+
Convert a 1D Hilbert curve index to 2D (x, y) coordinates.
|
15
|
+
Based on the algorithm from Wikipedia / common implementations.
|
16
|
+
|
17
|
+
Args:
|
18
|
+
n: Size of the grid (must be a power of 2).
|
19
|
+
d: 1D Hilbert curve index (0 to n*n-1).
|
20
|
+
|
21
|
+
Returns:
|
22
|
+
Tuple of (x, y) coordinates (column, row).
|
23
|
+
"""
|
24
|
+
x = y = 0
|
25
|
+
t = d
|
26
|
+
s = 1
|
27
|
+
while (s < n):
|
28
|
+
# Extract the two bits for the current level
|
29
|
+
rx = (t >> 1) & 1
|
30
|
+
ry = (t ^ rx) & 1 # Use XOR to determine the y bit based on d's pattern
|
31
|
+
|
32
|
+
# Rotate and flip the quadrant appropriately
|
33
|
+
if ry == 0:
|
34
|
+
if rx == 1:
|
35
|
+
x = (s - 1) - x
|
36
|
+
y = (s - 1) - y
|
37
|
+
# Swap x and y
|
38
|
+
x, y = y, x
|
39
|
+
|
40
|
+
# Add the offsets for the current quadrant
|
41
|
+
x += s * rx
|
42
|
+
y += s * ry
|
43
|
+
|
44
|
+
# Move to the next level
|
45
|
+
t >>= 2 # Equivalent to t //= 4
|
46
|
+
s <<= 1 # Equivalent to s *= 2
|
47
|
+
return x, y # Returns (column, row)
|
48
|
+
|
49
|
+
def hilbert_indices(H_P: int, W_P: int) -> jnp.ndarray:
    """
    Build the Hilbert-order traversal of an H_P x W_P patch grid.

    The curve is generated on the smallest power-of-two square that contains
    the grid; cells of that square falling outside the grid are skipped, so
    the remaining cells keep their relative Hilbert order.

    Args:
        H_P: Height in patches.
        W_P: Width in patches.

    Returns:
        1D int32 JAX array where result[i] is the row-major index
        (row * W_P + col) of the i-th patch visited by the curve. An empty
        array is returned when H_P * W_P is zero.
    """
    longest_side = max(H_P, W_P)
    order = math.ceil(math.log2(longest_side)) if longest_side > 0 else 0
    n = 2 ** order  # side of the enclosing square

    wanted = H_P * W_P
    if wanted == 0:
        return jnp.array([], dtype=jnp.int32)

    row_major = []
    for step in range(n * n):
        col, row = _d2xy(n, step)  # _d2xy yields (column, row)
        if col >= W_P or row >= H_P:
            # Cell belongs to the padding of the enclosing square.
            continue
        row_major.append(row * W_P + col)
        if len(row_major) == wanted:
            # Every real cell has been visited; the rest is padding.
            break

    return jnp.array(row_major, dtype=jnp.int32)
|
93
|
+
|
94
|
+
def inverse_permutation(idx: jnp.ndarray, total_size: int) -> jnp.ndarray:
    """
    Invert an index mapping (e.g. Hilbert position -> row-major index).

    Args:
        idx: Array of length N where idx[i] is the target index assigned to
             source index i. Values are assumed to be unique.
        total_size: Number of possible target indices (e.g., H_P * W_P).

    Returns:
        int32 array `inv` of length `total_size` with inv[idx[i]] = i for
        every source index i, and -1 at every target index that does not
        occur in `idx`.
    """
    # Source positions 0..N-1, written to the slots named by idx.
    positions = jnp.arange(idx.shape[0], dtype=jnp.int32)
    # Start from "nothing mapped"; -1 marks target slots never written.
    unmapped = jnp.full((total_size,), -1, dtype=jnp.int32)
    return unmapped.at[idx].set(positions)
|
121
|
+
|
122
|
+
# --- Patching Logic ---
|
123
|
+
|
124
|
+
def patchify(x: jnp.ndarray, patch_size: int) -> jnp.ndarray:
    """
    Split a batch of images into flattened square patches, row-major.

    Args:
        x: Image tensor of shape [B, H, W, C].
        patch_size: Side length of the square patches.

    Returns:
        Tensor of shape [B, N, P*P*C] with N = (H/ps)*(W/ps); patches are
        ordered left-to-right, then top-to-bottom.

    Raises:
        ValueError: If H or W is not a multiple of patch_size.
    """
    _, H, W, _ = x.shape
    evenly_tiled = (H % patch_size == 0) and (W % patch_size == 0)
    if not evenly_tiled:
        raise ValueError(f"Image dimensions ({H}, {W}) must be divisible by patch_size ({patch_size})")

    # (h w) is flattened into the sequence axis, (p1 p2 c) into the features.
    pattern = 'b (h p1) (w p2) c -> b (h w) (p1 p2 c)'
    return einops.rearrange(x, pattern, p1=patch_size, p2=patch_size)
|
145
|
+
|
146
|
+
def unpatchify(x: jnp.ndarray, patch_size: int, H: int, W: int, C: int) -> jnp.ndarray:
    """
    Reassemble row-major flattened patches into a batch of images.

    Args:
        x: Patch tensor of shape [B, N, P*P*C] where N = (H/ps) * (W/ps).
        patch_size: Side length of the square patches.
        H: Height of the image to rebuild.
        W: Width of the image to rebuild.
        C: Number of channels.

    Returns:
        Image tensor of shape [B, H, W, C].
    """
    rows, cols = H // patch_size, W // patch_size
    expected_patches = rows * cols
    actual_patches = x.shape[1]

    # The sequence length has to match the patch grid implied by H and W.
    assert actual_patches == expected_patches, \
        f"Number of patches ({actual_patches}) does not match expected ({expected_patches}) for H={H}, W={W}, patch_size={patch_size}"

    pattern = 'b (h w) (p1 p2 c) -> b (h p1) (w p2) c'
    return einops.rearrange(x, pattern, h=rows, w=cols, p1=patch_size, p2=patch_size, c=C)
|
174
|
+
|
175
|
+
def hilbert_patchify(x: jnp.ndarray, patch_size: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Patchify an image batch and arrange the patches along a Hilbert curve.

    Args:
        x: Image tensor of shape [B, H, W, C].
        patch_size: Side length of the square patches.

    Returns:
        Tuple of:
            - patches_hilbert: Patches in Hilbert order, [B, N, P*P*C]
              with N = H_P * W_P.
            - inv_idx: Array of length N mapping each row-major patch index
              to its position in the Hilbert sequence (-1 if unmapped).
    """
    _, height, width, _ = x.shape
    grid_rows = height // patch_size
    grid_cols = width // patch_size

    # Row-major patch sequence, shape [B, N, P*P*C].
    row_major = patchify(x, patch_size)

    # order[h] = row-major index of the h-th patch on the curve.
    order = hilbert_indices(grid_rows, grid_cols)

    # Inverse lookup (row-major index -> curve position), kept so the
    # original layout can be restored later.
    inv_idx = inverse_permutation(order, grid_rows * grid_cols)

    # Gather along the sequence axis to impose the Hilbert ordering.
    return row_major[:, order], inv_idx
|
209
|
+
|
210
|
+
def hilbert_unpatchify(x: jnp.ndarray, inv_idx: jnp.ndarray, patch_size: int, H: int, W: int, C: int) -> jnp.ndarray:
    """
    Undo the Hilbert reordering of a patch sequence and rebuild the image.

    The scatter back to row-major order is expressed as a gather followed by
    a mask, so every intermediate has a static shape.

    Args:
        x: Hilbert-ordered patches tensor [B, N, P*P*C].
        inv_idx: Array of length (H // patch_size) * (W // patch_size);
                 inv_idx[k] is the Hilbert position of row-major patch k,
                 or -1 if that patch has no source.
        patch_size: Side length of the square patches.
        H: Original image height.
        W: Original image width.
        C: Number of channels.

    Returns:
        Image tensor of shape [B, H, W, C]. Patches whose inv_idx entry is
        outside [0, N-1] are filled with zeros.
    """
    batch = x.shape[0]
    num_given = x.shape[1]  # patches actually present in x
    feat = x.shape[2]
    expected = (H // patch_size) * (W // patch_size)

    assert inv_idx.shape[0] == expected, \
        f"Inverse index size {inv_idx.shape[0]} does not match expected total patches {expected}"

    # Force every lookup into [0, N-1] so the gather itself is always legal;
    # entries that had to be clamped are blanked out by the mask below.
    safe_positions = jnp.clip(inv_idx, 0, num_given - 1)

    # Same index vector for every batch element: out[b, k] = x[b, pos[k]].
    gathered = x[:, safe_positions, :]

    # True where the row-major slot k really had a source patch in x.
    has_source = jnp.logical_and(inv_idx >= 0, inv_idx < num_given)

    blank = jnp.zeros((batch, expected, feat), dtype=x.dtype)
    row_major = jnp.where(has_source[None, :, None], gathered, blank)

    return unpatchify(row_major, patch_size, H, W, C)
|
280
|
+
# --- Visualization and Demo ---
|
281
|
+
|
282
|
+
def visualize_hilbert_curve(H: int, W: int, patch_size: int, figsize=(12, 5)):
    """
    Visualize the Hilbert curve mapping for a given image patch grid size.

    Two panels are drawn side by side: the patch grid labelled with its
    row-major indices, and the same grid labelled and coloured by Hilbert
    sequence index with the traversal path overlaid.

    Args:
        H: Image height.
        W: Image width.
        patch_size: Size of each patch.
        figsize: Figure size for the plot.

    Returns:
        The matplotlib Figure object, or None (after printing a warning)
        when the patch grid has no cells.
    """
    H_P = H // patch_size
    W_P = W // patch_size
    if H_P * W_P == 0:
        print("Warning: Grid dimensions are zero, cannot visualize.")
        return None

    # idx[i] = row-major index of the i-th point on the curve.
    idx = np.array(hilbert_indices(H_P, W_P))
    n_points = len(idx)

    # (row, col) of every visited cell, already in Hilbert order.
    rows, cols = np.divmod(idx, W_P)

    # grid[row, col] = Hilbert sequence index; -1 marks cells never visited.
    grid = np.full((H_P, W_P), -1.0)
    grid[rows, cols] = np.arange(n_points)

    # Colormap that transitions smoothly along the path:
    # Blue -> Green -> Yellow -> Red
    cmap = LinearSegmentedColormap.from_list('hilbert', ['#0000FF', '#00FF00', '#FFFF00', '#FF0000'])

    fig, axes = plt.subplots(1, 2, figsize=figsize)

    # --- Plot 1: Original Grid (Row-Major Order) ---
    orig_grid = np.arange(H_P * W_P).reshape((H_P, W_P))
    im0 = axes[0].imshow(orig_grid, cmap='viridis', aspect='auto')
    axes[0].set_title(f"Original Grid ({H_P}x{W_P})\n(Row-Major Order)")
    for r in range(H_P):
        for c in range(W_P):
            # White text on the dark half of viridis, black on the light half.
            axes[0].text(c, r, f'{orig_grid[r, c]}', ha='center', va='center', color='white' if orig_grid[r,c] < (H_P*W_P)/2 else 'black', fontsize=8)
    axes[0].set_xticks(np.arange(W_P))
    axes[0].set_yticks(np.arange(H_P))
    axes[0].set_xticklabels(np.arange(W_P))
    axes[0].set_yticklabels(np.arange(H_P))
    plt.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04, label="Row-Major Index")

    # --- Plot 2: Hilbert Curve Ordering ---
    masked_grid = np.ma.masked_where(grid == -1, grid)
    im1 = axes[1].imshow(masked_grid, cmap=cmap, aspect='auto', vmin=0, vmax=max(0, n_points-1))
    axes[1].set_title(f"Hilbert Curve Ordering ({n_points} points)")
    for r in range(H_P):
        for c in range(W_P):
            if grid[r,c] != -1:
                axes[1].text(c, r, f'{int(grid[r, c])}', ha='center', va='center', color='black', fontsize=8)
    axes[1].set_xticks(np.arange(W_P))
    axes[1].set_yticks(np.arange(H_P))
    axes[1].set_xticklabels(np.arange(W_P))
    axes[1].set_yticklabels(np.arange(H_P))
    plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04, label="Hilbert Sequence Index")

    # Draw the path through the cell centres. With imshow's default extent,
    # cell (row, col) is centred at data coordinates (col, row) -- the same
    # position used for the text labels above -- so no half-cell offset is
    # added here.
    if n_points > 1:
        axes[1].plot(cols, rows, color='black', linestyle='-', linewidth=1.5, alpha=0.8)
        axes[1].plot(cols[0], rows[0], 'go', markersize=8, label='Start (Idx 0)')  # Green circle
        axes[1].plot(cols[-1], rows[-1], 'mo', markersize=8, label=f'End (Idx {n_points-1})')  # Magenta circle
        axes[1].legend(fontsize='small')

    plt.tight_layout()
    return fig
|
377
|
+
|
378
|
+
def create_patch_grid(patches_np: np.ndarray, patch_size: int, channels: int, grid_cols: int = 10, border: int = 1):
    """
    Create a visualization grid from a sequence of patches.

    Patches are laid out left-to-right, top-to-bottom on a white canvas,
    separated by `border` pixels.

    Args:
        patches_np: Patch tensor [N, P*P*C] as NumPy array.
        patch_size: Size of square patches (P).
        channels: Number of channels (C).
        grid_cols: How many patches wide the grid should be. Values larger
            than N are reduced to N; values below 1 are treated as 1.
        border: Width of the border between patches.

    Returns:
        Grid image as NumPy array: 2D [grid_h, grid_w] when channels == 1,
        otherwise [grid_h, grid_w, C]. Float grids are clipped to [0, 1];
        integer grids are clipped to [0, 255] and returned as uint8.
        If there are no patches, or the patches cannot be reshaped to
        [N, P, P, C], a zero placeholder of shape [P, P, C] is returned
        (the reshape failure is reported via print).
    """
    n_patches = patches_np.shape[0]
    if n_patches == 0:
        return np.zeros((patch_size, patch_size, channels), dtype=patches_np.dtype)

    # Reshape flat patches to actual images [N, P, P, C]
    try:
        patch_imgs = patches_np.reshape(n_patches, patch_size, patch_size, channels)
    except ValueError as e:
        # Best-effort visualization helper: report and hand back a placeholder
        # instead of aborting the caller's figure.
        print(f"Error reshaping patches: {e}")
        print(f"Input shape: {patches_np.shape}, Expected P*P*C: {patch_size*patch_size*channels}")
        return np.zeros((patch_size, patch_size, channels), dtype=patches_np.dtype)

    # Determine grid size; keep at least one column so the row count below
    # never divides by zero.
    grid_cols = max(1, min(grid_cols, n_patches))
    grid_rows = int(np.ceil(n_patches / grid_cols))

    # Canvas size: one border strip between neighbours, none on the outside.
    cell = patch_size + border
    grid_h = grid_rows * cell - border
    grid_w = grid_cols * cell - border

    # White background (255 saturates to 1.0 for float images when clipped).
    # Grayscale canvases are created 2D from the start.
    canvas_shape = (grid_h, grid_w) if channels == 1 else (grid_h, grid_w, channels)
    grid = np.ones(canvas_shape, dtype=patch_imgs.dtype) * 255

    # Fill the grid with patches
    for i in range(n_patches):
        row, col = divmod(i, grid_cols)
        y_start = row * cell
        x_start = col * cell
        tile = patch_imgs[i, :, :, 0] if channels == 1 else patch_imgs[i]
        grid[y_start:y_start+patch_size, x_start:x_start+patch_size] = tile

    # Clip to valid range ([0, 1] for float, [0, 255] for int)
    if np.issubdtype(grid.dtype, np.floating):
        grid = np.clip(grid, 0, 1)
    elif np.issubdtype(grid.dtype, np.integer):
        grid = np.clip(grid, 0, 255).astype(np.uint8)  # Ensure uint8 for imshow

    # No squeeze here: the grayscale canvas is already 2D, and squeezing
    # would collapse a height or width of 1 and return a non-image array.
    return grid
|
447
|
+
|
448
|
+
|
449
|
+
def demo_hilbert_patching(image_np: np.ndarray, patch_size: int = 8, figsize=(15, 12)):
    """
    Demonstrate the Hilbert curve patching process on an image.

    Produces one figure showing the (cropped) image, the Hilbert path drawn
    over it, and the first patches in row-major vs. Hilbert order, and a
    second figure comparing the image with its reconstruction from the
    Hilbert-ordered patches. Shapes and the reconstruction error are printed.

    Args:
        image_np: NumPy array of shape [H, W, C] or [H, W].
        patch_size: Size of square patches.
        figsize: Figure size for the main plot.

    Returns:
        Tuple of (fig_main, fig_reconstruction) matplotlib Figure objects.
    """
    # Handle grayscale images by adding a trailing channel axis
    if image_np.ndim == 2:
        image_np = np.expand_dims(image_np, axis=-1)

    # Crop so that both spatial dimensions are whole multiples of patch_size
    H_orig, W_orig, C = image_np.shape
    H = (H_orig // patch_size) * patch_size
    W = (W_orig // patch_size) * patch_size
    if H != H_orig or W != W_orig:
        print(f"Warning: Cropping image from ({H_orig}, {W_orig}) to ({H}, {W}) to be divisible by patch_size={patch_size}")
        image_np = image_np[:H, :W, :]

    # Convert to JAX array and add batch dimension
    image = jnp.expand_dims(jnp.array(image_np), axis=0)  # [1, H, W, C]
    _, H, W, C = image.shape
    H_P = H // patch_size
    W_P = W // patch_size

    print(f"Image shape: {image.shape}, Patch size: {patch_size}, Grid: {H_P}x{W_P}")

    image_cmap = 'gray' if C==1 else None

    # --- Create Main Visualization Figure ---
    fig, axes = plt.subplots(2, 2, figsize=figsize)

    # 1. Original image (cropped)
    display_img = np.array(image[0])  # Back to numpy for display
    axes[0, 0].imshow(display_img.squeeze(), cmap=image_cmap)
    axes[0, 0].set_title(f"Original Image ({H}x{W})")
    axes[0, 0].axis('off')

    # 2. Original image with Hilbert curve overlay
    axes[0, 1].imshow(display_img.squeeze(), cmap=image_cmap)
    axes[0, 1].set_title("Image with Hilbert Curve Overlay")

    # idx[i] is the row-major index of the i-th patch on the curve, so the
    # patch (row, col) positions come out directly in Hilbert order.
    idx = np.array(hilbert_indices(H_P, W_P))
    if len(idx) > 1:
        rows, cols = np.divmod(idx, W_P)
        # imshow centres pixel j at coordinate j, so a patch covering pixels
        # [k*P, k*P + P - 1] is centred at k*P + (P - 1) / 2.
        centre_offset = (patch_size - 1) / 2
        y_coords = rows * patch_size + centre_offset
        x_coords = cols * patch_size + centre_offset
        axes[0, 1].plot(x_coords, y_coords, 'r-', linewidth=1.5, alpha=0.7)
        axes[0, 1].plot(x_coords[0], y_coords[0], 'go', markersize=5)  # Start
        axes[0, 1].plot(x_coords[-1], y_coords[-1], 'mo', markersize=5)  # End
    axes[0, 1].axis('off')

    # 3. Apply Hilbert Patchify
    patches_hilbert, inv_idx = hilbert_patchify(image, patch_size)
    print(f"Hilbert patches shape: {patches_hilbert.shape}")  # [B, N, P*P*C]
    print(f"Inverse index shape: {inv_idx.shape}")  # [total_patches_expected]

    # For comparison, get row-major patches
    patches_row_major = patchify(image, patch_size)
    print(f"Row-major patches shape: {patches_row_major.shape}")  # [B, N, P*P*C]

    # Display a subset of patches in both orderings
    n_display = min(60, patches_hilbert.shape[1])

    # Convert JAX arrays to NumPy for the visualization helper
    patches_hilbert_np = np.array(patches_hilbert[0, :n_display])
    patches_row_major_np = np.array(patches_row_major[0, :n_display])

    patch_grid_row = create_patch_grid(patches_row_major_np, patch_size, C, grid_cols=10)
    patch_grid_hil = create_patch_grid(patches_hilbert_np, patch_size, C, grid_cols=10)

    axes[1, 0].imshow(patch_grid_row, cmap=image_cmap, aspect='auto')
    axes[1, 0].set_title(f"First {n_display} Patches (Row-Major Order)")
    axes[1, 0].axis('off')

    axes[1, 1].imshow(patch_grid_hil, cmap=image_cmap, aspect='auto')
    axes[1, 1].set_title(f"First {n_display} Patches (Hilbert Order)")
    axes[1, 1].axis('off')

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # Leave room at the top for the suptitle
    fig.suptitle(f"Hilbert Patching Demo (Patch Size: {patch_size}x{patch_size})", fontsize=16)

    # --- Create Reconstruction Figure ---
    fig2, axes2 = plt.subplots(1, 2, figsize=(12, 6))

    # 4. Unpatchify and verify
    reconstructed = hilbert_unpatchify(patches_hilbert, inv_idx, patch_size, H, W, C)
    print(f"Reconstructed image shape: {reconstructed.shape}")

    # Reconstruction error, pulled to a Python float so the format specs
    # below do not depend on array formatting support.
    error = float(jnp.mean(jnp.abs(image - reconstructed)))
    print(f"Reconstruction Mean Absolute Error: {error:.6f}")

    # Display original and reconstructed
    reconstructed_np = np.array(reconstructed[0])  # Back to numpy
    axes2[0].imshow(display_img.squeeze(), cmap=image_cmap)
    axes2[0].set_title("Original Image (Cropped)")
    axes2[0].axis('off')

    axes2[1].imshow(reconstructed_np.squeeze(), cmap=image_cmap)
    axes2[1].set_title(f"Reconstructed from Hilbert Patches\nMAE: {error:.4f}")
    axes2[1].axis('off')

    plt.tight_layout()
    fig2.suptitle("Image Reconstruction Verification", fontsize=16)

    return fig, fig2
|
577
|
+
|
578
|
+
|
579
|
+
# --- Example Usage ---
|
580
|
+
if __name__ == '__main__':
    # Synthetic test image: smooth gradients so patch order is easy to see.
    H, W, C = 64, 80, 3  # Rectangular image
    # H, W, C = 64, 64, 1  # Square grayscale image
    xv, yv = np.meshgrid(np.linspace(0, 1, W), np.linspace(0, 1, H))
    diagonal = (xv + yv) / 2

    img_np = np.zeros((H, W, C), dtype=np.float32)
    if C == 3:
        # Red follows width, green follows height, blue is their average.
        img_np[...] = np.stack([xv, yv, diagonal], axis=-1)
    else:  # Grayscale
        img_np[..., 0] = diagonal

    # --- Test Visualization ---
    patch_size_vis = 16
    vis_rows, vis_cols = 4, 5
    H_vis = vis_rows * patch_size_vis  # e.g., 64
    W_vis = vis_cols * patch_size_vis  # e.g., 80
    print(f"\nVisualizing Hilbert curve for {H_vis//patch_size_vis}x{W_vis//patch_size_vis} patch grid...")
    fig_vis = visualize_hilbert_curve(H_vis, W_vis, patch_size_vis)
    if fig_vis:
        # fig_vis.savefig("hilbert_curve_visualization.png")
        plt.show()  # Display the plot

    # --- Test Patching Demo ---
    patch_size_demo = 8
    print(f"\nRunning Hilbert patching demo with patch size {patch_size_demo}...")
    fig_main, fig_recon = demo_hilbert_patching(img_np, patch_size=patch_size_demo)
    # fig_main.savefig("hilbert_patching_demo.png")
    # fig_recon.savefig("hilbert_reconstruction.png")
    plt.show()  # Display the plots

    # --- Test edge case: small, non-power-of-two patch grid ---
    print("\nTesting small image (3x5 patches)...")
    small_patch = 4
    img_small_np = img_np[:3 * small_patch, :5 * small_patch, :]  # 12x20 image
    fig_main_small, fig_recon_small = demo_hilbert_patching(img_small_np, patch_size=small_patch)
    plt.show()
|