flaxdiff-0.2.7-py3-none-any.whl → flaxdiff-0.2.8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/dataloaders.py +36 -24
- flaxdiff/data/dataset_map.py +2 -2
- flaxdiff/data/sources/base.py +12 -0
- flaxdiff/data/sources/images.py +68 -11
- flaxdiff/data/sources/videos.py +5 -0
- flaxdiff/models/common.py +1 -70
- flaxdiff/models/hilbert.py +617 -0
- flaxdiff/models/simple_dit.py +476 -0
- flaxdiff/models/simple_mmdit.py +861 -0
- flaxdiff/models/simple_vit.py +278 -117
- flaxdiff/trainer/general_diffusion_trainer.py +29 -10
- flaxdiff/trainer/simple_trainer.py +113 -19
- {flaxdiff-0.2.7.dist-info → flaxdiff-0.2.8.dist-info}/METADATA +1 -1
- {flaxdiff-0.2.7.dist-info → flaxdiff-0.2.8.dist-info}/RECORD +16 -14
- {flaxdiff-0.2.7.dist-info → flaxdiff-0.2.8.dist-info}/WHEEL +1 -1
- flaxdiff/models/better_uvit.py +0 -380
- {flaxdiff-0.2.7.dist-info → flaxdiff-0.2.8.dist-info}/top_level.txt +0 -0
flaxdiff/models/simple_vit.py
CHANGED
@@ -10,15 +10,19 @@ from flaxdiff.models.simple_unet import FourierEmbedding, TimeProjection, ConvLa
|
|
10
10
|
import einops
|
11
11
|
from flax.typing import Dtype, PrecisionLike
|
12
12
|
from functools import partial
|
13
|
-
from .
|
13
|
+
from .hilbert import hilbert_indices, inverse_permutation, hilbert_patchify, hilbert_unpatchify
|
14
|
+
|
14
15
|
|
15
16
|
def unpatchify(x, channels=3):
|
16
17
|
patch_size = int((x.shape[2] // channels) ** 0.5)
|
17
18
|
h = w = int(x.shape[1] ** .5)
|
18
|
-
assert h * w == x.shape[1] and patch_size ** 2 *
|
19
|
-
|
19
|
+
assert h * w == x.shape[1] and patch_size ** 2 * \
|
20
|
+
channels == x.shape[2], f"Invalid shape: {x.shape}, should be {h*w}, {patch_size**2*channels}"
|
21
|
+
x = einops.rearrange(
|
22
|
+
x, 'B (h w) (p1 p2 C) -> B (h p1) (w p2) C', h=h, p1=patch_size, p2=patch_size)
|
20
23
|
return x
|
21
24
|
|
25
|
+
|
22
26
|
class PatchEmbedding(nn.Module):
|
23
27
|
patch_size: int
|
24
28
|
embedding_dim: int
|
@@ -29,15 +33,16 @@ class PatchEmbedding(nn.Module):
|
|
29
33
|
def __call__(self, x):
|
30
34
|
batch, height, width, channels = x.shape
|
31
35
|
assert height % self.patch_size == 0 and width % self.patch_size == 0, "Image dimensions must be divisible by patch size"
|
32
|
-
|
33
|
-
x = nn.Conv(features=self.embedding_dim,
|
34
|
-
kernel_size=(self.patch_size, self.patch_size),
|
36
|
+
|
37
|
+
x = nn.Conv(features=self.embedding_dim,
|
38
|
+
kernel_size=(self.patch_size, self.patch_size),
|
35
39
|
strides=(self.patch_size, self.patch_size),
|
36
40
|
dtype=self.dtype,
|
37
41
|
precision=self.precision)(x)
|
38
42
|
x = jnp.reshape(x, (batch, -1, self.embedding_dim))
|
39
43
|
return x
|
40
44
|
|
45
|
+
|
41
46
|
class PositionalEncoding(nn.Module):
|
42
47
|
max_len: int
|
43
48
|
embedding_dim: int
|
@@ -49,138 +54,294 @@ class PositionalEncoding(nn.Module):
|
|
49
54
|
(1, self.max_len, self.embedding_dim))
|
50
55
|
return x + pe[:, :x.shape[1], :]
|
51
56
|
|
57
|
+
|
52
58
|
class UViT(nn.Module):
|
53
|
-
output_channels:int=3
|
59
|
+
output_channels: int = 3
|
54
60
|
patch_size: int = 16
|
55
|
-
emb_features:int=768
|
56
|
-
num_layers: int = 12
|
61
|
+
emb_features: int = 768
|
62
|
+
num_layers: int = 12 # Should be even for U-Net structure
|
57
63
|
num_heads: int = 12
|
58
|
-
dropout_rate: float = 0.1
|
59
|
-
use_projection: bool = False
|
60
|
-
use_flash_attention: bool = False
|
64
|
+
dropout_rate: float = 0.1 # Dropout is often 0 in diffusion models
|
65
|
+
use_projection: bool = False # In TransformerBlock MLP
|
66
|
+
use_flash_attention: bool = False # Passed to TransformerBlock
|
67
|
+
# Passed to TransformerBlock (likely False for UViT)
|
61
68
|
use_self_and_cross: bool = False
|
62
|
-
force_fp32_for_softmax: bool = True
|
63
|
-
|
64
|
-
|
65
|
-
|
69
|
+
force_fp32_for_softmax: bool = True # Passed to TransformerBlock
|
70
|
+
# Used in final convs if add_residualblock_output
|
71
|
+
activation: Callable = jax.nn.swish
|
72
|
+
norm_groups: int = 8
|
73
|
+
dtype: Optional[Dtype] = None # e.g., jnp.float32 or jnp.bfloat16
|
66
74
|
precision: PrecisionLike = None
|
67
75
|
add_residualblock_output: bool = False
|
68
|
-
norm_inputs: bool = False
|
69
|
-
explicitly_add_residual: bool = True
|
70
|
-
norm_epsilon: float = 1e-
|
71
|
-
use_hilbert: bool = False
|
76
|
+
norm_inputs: bool = False # Passed to TransformerBlock
|
77
|
+
explicitly_add_residual: bool = True # Passed to TransformerBlock
|
78
|
+
norm_epsilon: float = 1e-5 # Adjusted default
|
79
|
+
use_hilbert: bool = False # Toggle Hilbert patch reorder
|
80
|
+
use_remat: bool = False # Add flag to use remat
|
72
81
|
|
73
82
|
def setup(self):
|
83
|
+
assert self.num_layers % 2 == 0, "num_layers must be even for U-Net structure"
|
84
|
+
half_layers = self.num_layers // 2
|
85
|
+
|
86
|
+
# --- Norm Layer ---
|
74
87
|
if self.norm_groups > 0:
|
75
|
-
|
88
|
+
# GroupNorm needs features arg, which varies. Define partial here, apply in __call__?
|
89
|
+
# Or maybe use LayerNorm/RMSNorm consistently? Let's use LayerNorm for simplicity here.
|
90
|
+
# If GroupNorm is essential, it needs careful handling with changing feature sizes.
|
91
|
+
# self.norm_factory = partial(nn.GroupNorm, self.norm_groups, epsilon=self.norm_epsilon, dtype=self.dtype)
|
92
|
+
print(f"Warning: norm_groups > 0 not fully supported with standard LayerNorm fallback in UViT setup. Using LayerNorm.")
|
93
|
+
self.norm_factory = partial(
|
94
|
+
nn.LayerNorm, epsilon=self.norm_epsilon, dtype=self.dtype)
|
95
|
+
else:
|
96
|
+
# Use LayerNorm or RMSNorm for sequence normalization
|
97
|
+
# self.norm_factory = partial(nn.RMSNorm, epsilon=self.norm_epsilon, dtype=self.dtype)
|
98
|
+
self.norm_factory = partial(
|
99
|
+
nn.LayerNorm, epsilon=self.norm_epsilon, dtype=self.dtype)
|
100
|
+
|
101
|
+
# --- Input Path ---
|
102
|
+
self.patch_embed = PatchEmbedding(
|
103
|
+
patch_size=self.patch_size,
|
104
|
+
embedding_dim=self.emb_features,
|
105
|
+
dtype=self.dtype,
|
106
|
+
precision=self.precision,
|
107
|
+
name="patch_embed"
|
108
|
+
)
|
109
|
+
if self.use_hilbert:
|
110
|
+
# Projection layer needed after raw Hilbert patches
|
111
|
+
self.hilbert_proj = nn.Dense(
|
112
|
+
features=self.emb_features,
|
113
|
+
dtype=self.dtype,
|
114
|
+
precision=self.precision,
|
115
|
+
name="hilbert_projection"
|
116
|
+
)
|
117
|
+
|
118
|
+
# Positional encoding (learned) - applied only to patch tokens
|
119
|
+
# Max length needs to accommodate max possible patches
|
120
|
+
# Example: 512x512 image, patch 16 -> (512/16)^2 = 1024 patches
|
121
|
+
# Estimate max patches, adjust if needed
|
122
|
+
max_patches = (512 // self.patch_size)**2
|
123
|
+
self.pos_encoding = self.param('pos_encoding',
|
124
|
+
# Standard init for ViT pos embeds
|
125
|
+
jax.nn.initializers.normal(stddev=0.02),
|
126
|
+
(1, max_patches, self.emb_features))
|
127
|
+
|
128
|
+
# --- Conditioning ---
|
129
|
+
self.time_embed = nn.Sequential([
|
130
|
+
FourierEmbedding(features=self.emb_features),
|
131
|
+
TimeProjection(features=self.emb_features)
|
132
|
+
], name="time_embed")
|
133
|
+
|
134
|
+
# Text projection
|
135
|
+
self.text_proj = nn.DenseGeneral(
|
136
|
+
features=self.emb_features,
|
137
|
+
dtype=self.dtype,
|
138
|
+
precision=self.precision,
|
139
|
+
name="text_proj"
|
140
|
+
)
|
141
|
+
|
142
|
+
# --- Transformer Blocks ---
|
143
|
+
BlockClass = TransformerBlock
|
144
|
+
|
145
|
+
self.down_blocks = [
|
146
|
+
BlockClass(
|
147
|
+
heads=self.num_heads,
|
148
|
+
dim_head=self.emb_features // self.num_heads,
|
149
|
+
dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
|
150
|
+
use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross,
|
151
|
+
force_fp32_for_softmax=self.force_fp32_for_softmax,
|
152
|
+
only_pure_attention=False, norm_inputs=self.norm_inputs,
|
153
|
+
explicitly_add_residual=self.explicitly_add_residual,
|
154
|
+
norm_epsilon=self.norm_epsilon,
|
155
|
+
name=f"down_block_{i}"
|
156
|
+
) for i in range(half_layers)
|
157
|
+
]
|
158
|
+
|
159
|
+
self.mid_block = BlockClass(
|
160
|
+
heads=self.num_heads,
|
161
|
+
dim_head=self.emb_features // self.num_heads,
|
162
|
+
dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
|
163
|
+
use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross,
|
164
|
+
force_fp32_for_softmax=self.force_fp32_for_softmax,
|
165
|
+
only_pure_attention=False, norm_inputs=self.norm_inputs,
|
166
|
+
explicitly_add_residual=self.explicitly_add_residual,
|
167
|
+
norm_epsilon=self.norm_epsilon,
|
168
|
+
name="mid_block"
|
169
|
+
)
|
170
|
+
|
171
|
+
self.up_dense = [
|
172
|
+
nn.DenseGeneral( # Project concatenated skip + up_path features back to emb_features
|
173
|
+
features=self.emb_features,
|
174
|
+
dtype=self.dtype,
|
175
|
+
precision=self.precision,
|
176
|
+
name=f"up_dense_{i}"
|
177
|
+
) for i in range(half_layers)
|
178
|
+
]
|
179
|
+
self.up_blocks = [
|
180
|
+
BlockClass(
|
181
|
+
heads=self.num_heads,
|
182
|
+
dim_head=self.emb_features // self.num_heads,
|
183
|
+
dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
|
184
|
+
use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross,
|
185
|
+
force_fp32_for_softmax=self.force_fp32_for_softmax,
|
186
|
+
only_pure_attention=False, norm_inputs=self.norm_inputs,
|
187
|
+
explicitly_add_residual=self.explicitly_add_residual,
|
188
|
+
norm_epsilon=self.norm_epsilon,
|
189
|
+
name=f"up_block_{i}"
|
190
|
+
) for i in range(half_layers)
|
191
|
+
]
|
192
|
+
|
193
|
+
# --- Output Path ---
|
194
|
+
self.final_norm = self.norm_factory(name="final_norm") # Use factory
|
195
|
+
|
196
|
+
patch_dim = self.patch_size ** 2 * self.output_channels
|
197
|
+
self.final_proj = nn.Dense(
|
198
|
+
features=patch_dim,
|
199
|
+
dtype=self.dtype, # Keep model dtype for projection
|
200
|
+
precision=self.precision,
|
201
|
+
kernel_init=nn.initializers.zeros, # Zero init final layer
|
202
|
+
name="final_proj"
|
203
|
+
)
|
204
|
+
|
205
|
+
if self.add_residualblock_output:
|
206
|
+
# Define these layers only if needed
|
207
|
+
self.final_conv1 = ConvLayer(
|
208
|
+
"conv",
|
209
|
+
features=64, kernel_size=(3, 3), strides=(1, 1),
|
210
|
+
dtype=self.dtype, precision=self.precision, name="final_conv1"
|
211
|
+
)
|
212
|
+
self.final_norm_conv = self.norm_factory(
|
213
|
+
name="final_norm_conv") # Use factory
|
214
|
+
self.final_conv2 = ConvLayer(
|
215
|
+
"conv",
|
216
|
+
features=self.output_channels, kernel_size=(3, 3), strides=(1, 1),
|
217
|
+
dtype=jnp.float32, # Often good to have final conv output float32
|
218
|
+
precision=self.precision, name="final_conv2"
|
219
|
+
)
|
76
220
|
else:
|
77
|
-
|
78
|
-
|
221
|
+
# Final conv to map features to output channels directly after unpatchify
|
222
|
+
self.final_conv_direct = ConvLayer(
|
223
|
+
"conv",
|
224
|
+
# Use 1x1 conv
|
225
|
+
features=self.output_channels, kernel_size=(1, 1), strides=(1, 1),
|
226
|
+
dtype=jnp.float32, # Output float32
|
227
|
+
precision=self.precision, name="final_conv_direct"
|
228
|
+
)
|
229
|
+
|
79
230
|
@nn.compact
|
80
231
|
def __call__(self, x, temb, textcontext=None):
|
81
|
-
#
|
82
|
-
temb = FourierEmbedding(features=self.emb_features)(temb)
|
83
|
-
temb = TimeProjection(features=self.emb_features)(temb)
|
84
|
-
|
85
|
-
original_img = x
|
232
|
+
original_img = x # Keep original for potential residual connection
|
86
233
|
B, H, W, C = original_img.shape
|
87
234
|
H_P = H // self.patch_size
|
88
235
|
W_P = W // self.patch_size
|
236
|
+
num_patches = H_P * W_P
|
237
|
+
assert H % self.patch_size == 0 and W % self.patch_size == 0, "Image dimensions must be divisible by patch size"
|
89
238
|
|
90
|
-
# Patch
|
91
|
-
|
92
|
-
dtype=self.dtype, precision=self.precision)(x)
|
93
|
-
num_patches = x.shape[1]
|
94
|
-
|
95
|
-
# Optional Hilbert reorder
|
239
|
+
# --- Patch Embedding ---
|
240
|
+
hilbert_inv_idx = None
|
96
241
|
if self.use_hilbert:
|
242
|
+
# Use hilbert_patchify to get raw patches and inverse index
|
243
|
+
patches_raw, hilbert_inv_idx_calc = hilbert_patchify(
|
244
|
+
x, self.patch_size) # Shape [B, S, P*P*C]
|
245
|
+
# Project raw patches
|
246
|
+
# Shape [B, S, emb_features]
|
247
|
+
x_patches = self.hilbert_proj(patches_raw)
|
248
|
+
# Calculate inverse permutation (needs total_size)
|
97
249
|
idx = hilbert_indices(H_P, W_P)
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
250
|
+
hilbert_inv_idx = inverse_permutation(
|
251
|
+
idx, total_size=num_patches) # Corrected call
|
252
|
+
# Apply Hilbert reordering *after* projection
|
253
|
+
x_patches = x_patches[:, idx, :]
|
254
|
+
else:
|
255
|
+
# Standard patch embedding
|
256
|
+
# Shape: [B, num_patches, emb_features]
|
257
|
+
x_patches = self.patch_embed(x)
|
258
|
+
|
259
|
+
# --- Positional Encoding ---
|
260
|
+
# Add positional encoding only to patch tokens
|
261
|
+
assert num_patches <= self.pos_encoding.shape[
|
262
|
+
1], f"Number of patches {num_patches} exceeds max_len {self.pos_encoding.shape[1]} in positional encoding"
|
263
|
+
x_patches = x_patches + self.pos_encoding[:, :num_patches, :]
|
264
|
+
|
265
|
+
# --- Conditioning Tokens ---
|
266
|
+
# Time embedding: [B, D] -> [B, 1, D]
|
267
|
+
time_token = self.time_embed(temb.astype(
|
268
|
+
jnp.float32)) # Ensure input is float32
|
269
|
+
time_token = jnp.expand_dims(time_token.astype(
|
270
|
+
self.dtype), axis=1) # Cast back and add seq dim
|
271
|
+
|
272
|
+
# Text embedding: [B, S_text, D_in] -> [B, S_text, D]
|
273
|
+
if textcontext is not None:
|
274
|
+
text_tokens = self.text_proj(
|
275
|
+
textcontext.astype(self.dtype)) # Cast context
|
276
|
+
num_text_tokens = text_tokens.shape[1]
|
277
|
+
# Concatenate: [Patches+Pos, Time, Text]
|
278
|
+
x = jnp.concatenate([x_patches, time_token, text_tokens], axis=1)
|
279
|
+
else:
|
280
|
+
# Concatenate: [Patches+Pos, Time]
|
281
|
+
num_text_tokens = 0
|
282
|
+
x = jnp.concatenate([x_patches, time_token], axis=1)
|
283
|
+
|
284
|
+
# --- U-Net Transformer ---
|
112
285
|
skips = []
|
113
|
-
#
|
286
|
+
# Down blocks (Encoder)
|
114
287
|
for i in range(self.num_layers // 2):
|
115
|
-
x =
|
116
|
-
|
117
|
-
|
118
|
-
only_pure_attention=False,
|
119
|
-
norm_inputs=self.norm_inputs,
|
120
|
-
explicitly_add_residual=self.explicitly_add_residual,
|
121
|
-
norm_epsilon=self.norm_epsilon, # Pass epsilon
|
122
|
-
)(x)
|
123
|
-
skips.append(x)
|
124
|
-
|
288
|
+
x = self.down_blocks[i](x) # Pass full sequence (patches+cond)
|
289
|
+
skips.append(x) # Store output for skip connection
|
290
|
+
|
125
291
|
# Middle block
|
126
|
-
x =
|
127
|
-
|
128
|
-
|
129
|
-
only_pure_attention=False,
|
130
|
-
norm_inputs=self.norm_inputs,
|
131
|
-
explicitly_add_residual=self.explicitly_add_residual,
|
132
|
-
norm_epsilon=self.norm_epsilon, # Pass epsilon
|
133
|
-
)(x)
|
134
|
-
|
135
|
-
# Out blocks
|
292
|
+
x = self.mid_block(x)
|
293
|
+
|
294
|
+
# Up blocks (Decoder)
|
136
295
|
for i in range(self.num_layers // 2):
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
x = self.norm()(x) # Uses norm_epsilon defined in setup
|
150
|
-
|
151
|
-
patch_dim = self.patch_size ** 2 * self.output_channels
|
152
|
-
x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision)(x)
|
153
|
-
# If Hilbert, restore original patch order
|
154
|
-
if self.use_hilbert:
|
155
|
-
x = x[:, inv_idx, :]
|
296
|
+
skip_conn = skips.pop()
|
297
|
+
# Concatenate along feature dimension
|
298
|
+
x = jnp.concatenate([x, skip_conn], axis=-1)
|
299
|
+
# Project back to emb_features
|
300
|
+
x = self.up_dense[i](x)
|
301
|
+
# Apply transformer block
|
302
|
+
x = self.up_blocks[i](x)
|
303
|
+
|
304
|
+
# --- Output Processing ---
|
305
|
+
# Normalize before final projection
|
306
|
+
x = self.final_norm(x) # Apply norm factory instance
|
307
|
+
|
156
308
|
# Extract only the image patch tokens (first num_patches tokens)
|
157
|
-
|
158
|
-
|
159
|
-
|
309
|
+
# Conditioning tokens (time, text) are discarded here
|
310
|
+
x_patches_out = x[:, :num_patches, :]
|
311
|
+
|
312
|
+
# Project to patch pixel dimensions
|
313
|
+
# Shape: [B, num_patches, patch_dim]
|
314
|
+
x_patches_out = self.final_proj(x_patches_out)
|
315
|
+
|
316
|
+
# --- Unpatchify ---
|
317
|
+
if self.use_hilbert:
|
318
|
+
# Restore Hilbert order to row-major order and then to image
|
319
|
+
assert hilbert_inv_idx is not None, "Hilbert inverse index missing"
|
320
|
+
x_image = hilbert_unpatchify(
|
321
|
+
x_patches_out, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
|
322
|
+
else:
|
323
|
+
# Standard unpatchify
|
324
|
+
# Shape: [B, H, W, C_out]
|
325
|
+
x_image = unpatchify(x_patches_out, channels=self.output_channels)
|
326
|
+
|
327
|
+
# --- Final Convolutions ---
|
160
328
|
if self.add_residualblock_output:
|
161
|
-
# Concatenate the original image
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
kernel_size=(3, 3),
|
181
|
-
strides=(1, 1),
|
182
|
-
# activation=jax.nn.mish
|
183
|
-
dtype=self.dtype,
|
184
|
-
precision=self.precision
|
185
|
-
)(x)
|
186
|
-
return x
|
329
|
+
# Concatenate the original image (ensure dtype matches)
|
330
|
+
x_image = jnp.concatenate(
|
331
|
+
[original_img.astype(self.dtype), x_image], axis=-1)
|
332
|
+
|
333
|
+
x_image = self.final_conv1(x_image)
|
334
|
+
# Apply norm factory instance
|
335
|
+
x_image = self.final_norm_conv(x_image)
|
336
|
+
x_image = self.activation(x_image)
|
337
|
+
x_image = self.final_conv2(x_image) # Outputs float32
|
338
|
+
else:
|
339
|
+
# Apply a simple 1x1 conv to map features if needed (unpatchify already gives C_out channels)
|
340
|
+
# Or just return x_image if channels match output_channels
|
341
|
+
# If unpatchify output channels == self.output_channels, this might be redundant
|
342
|
+
# Let's assume unpatchify gives correct channels, but ensure float32
|
343
|
+
# x_image = self.final_conv_direct(x_image) # Use 1x1 conv if needed
|
344
|
+
pass # Assuming unpatchify output is correct
|
345
|
+
|
346
|
+
# Ensure final output is float32
|
347
|
+
return x_image
|
flaxdiff/trainer/general_diffusion_trainer.py
CHANGED
@@ -129,6 +129,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
129
129
|
frames_per_sample: int = None,
|
130
130
|
wandb_config: Dict[str, Any] = None,
|
131
131
|
eval_metrics: List[EvaluationMetric] = None,
|
132
|
+
best_tracker_metric: str = "train/best_loss",
|
132
133
|
**kwargs
|
133
134
|
):
|
134
135
|
"""
|
@@ -196,6 +197,8 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
196
197
|
**kwargs
|
197
198
|
)
|
198
199
|
|
200
|
+
self.best_tracker_metric = best_tracker_metric
|
201
|
+
|
199
202
|
# Store video-specific parameters
|
200
203
|
self.frames_per_sample = frames_per_sample
|
201
204
|
|
@@ -203,6 +206,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
203
206
|
self.conditional_inputs = input_config.conditions
|
204
207
|
# Determine if we're working with video or images
|
205
208
|
self.is_video = self._is_video_data()
|
209
|
+
self.best_val_metrics = {}
|
206
210
|
|
207
211
|
def _is_video_data(self):
|
208
212
|
sample_data_shape = self.input_config.sample_data_shape
|
@@ -423,7 +427,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
423
427
|
process_index = jax.process_index()
|
424
428
|
generate_samples = val_step_fn
|
425
429
|
|
426
|
-
val_ds = iter(val_ds
|
430
|
+
val_ds = iter(val_ds) if val_ds else None
|
427
431
|
# Evaluation step
|
428
432
|
try:
|
429
433
|
metrics = {metric.name: [] for metric in self.eval_metrics} if self.eval_metrics else {}
|
@@ -465,11 +469,17 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
465
469
|
else: # [B,H,W,C] - Image data
|
466
470
|
self._log_image_samples(samples, current_step)
|
467
471
|
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
472
|
+
# Flatten the metrics
|
473
|
+
if metrics:
|
474
|
+
metrics = {k: np.mean(v) for k, v in metrics.items()}
|
475
|
+
# Update the best validation metrics
|
476
|
+
for key, value in metrics.items():
|
477
|
+
if key not in self.best_val_metrics:
|
478
|
+
self.best_val_metrics[key] = value
|
479
|
+
else:
|
480
|
+
self.best_val_metrics[key] = min(self.best_val_metrics[key], value)
|
481
|
+
# Log the best validation metrics
|
482
|
+
if getattr(self, 'wandb', None) is not None and self.wandb:
|
473
483
|
# Log the metrics
|
474
484
|
for key, value in metrics.items():
|
475
485
|
if isinstance(value, jnp.ndarray):
|
@@ -477,7 +487,10 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
477
487
|
self.wandb.log({
|
478
488
|
f"val/{key}": value,
|
479
489
|
}, step=current_step)
|
480
|
-
|
490
|
+
|
491
|
+
|
492
|
+
# Close validation dataset iterator
|
493
|
+
del val_ds
|
481
494
|
except StopIteration:
|
482
495
|
print(f"Validation dataset exhausted for process index {process_index}")
|
483
496
|
except Exception as e:
|
@@ -602,9 +615,13 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
602
615
|
"""
|
603
616
|
if self.wandb is None:
|
604
617
|
raise ValueError("Wandb is not initialized. Cannot get best runs.")
|
605
|
-
|
618
|
+
import wandb
|
606
619
|
# Get the sweep runs
|
607
|
-
runs =
|
620
|
+
runs = [i for i in wandb.Api().runs(path=f"{self.wandb.entity}/{self.wandb.project}", filters={"config.dataset.name": self.wandb.config['dataset']['name']})]
|
621
|
+
if not runs:
|
622
|
+
raise ValueError("No runs found in wandb.")
|
623
|
+
print(f"Getting best runs from wandb {self.wandb.id}...")
|
624
|
+
runs = sorted(runs, key=lambda x: x.summary.get(metric, float('inf')))
|
608
625
|
best_runs = runs[:top_k]
|
609
626
|
lower_bound = best_runs[-1].summary.get(metric, float('inf'))
|
610
627
|
upper_bound = best_runs[0].summary.get(metric, float('inf'))
|
@@ -636,6 +653,8 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
636
653
|
# Check if current run is one of the best
|
637
654
|
if metric == "train/best_loss":
|
638
655
|
current_run_metric = self.best_loss
|
656
|
+
elif metric in self.best_val_metrics:
|
657
|
+
current_run_metric = self.best_val_metrics[metric]
|
639
658
|
else:
|
640
659
|
current_run_metric = self.wandb.summary.get(metric, float('inf') if is_lower_better else float('-inf'))
|
641
660
|
|
@@ -653,7 +672,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
653
672
|
if self.wandb is not None:
|
654
673
|
checkpoint = get_latest_checkpoint(self.checkpoint_path())
|
655
674
|
try:
|
656
|
-
is_good, is_best = self.__compare_run_against_best__(top_k=5, metric=
|
675
|
+
is_good, is_best = self.__compare_run_against_best__(top_k=5, metric=self.best_tracker_metric, from_sweeps=hasattr(self, "wandb_sweep"))
|
657
676
|
if is_good:
|
658
677
|
# Push to registry with appropriate aliases
|
659
678
|
aliases = []
|