flaxdiff 0.2.7__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,15 +10,19 @@ from flaxdiff.models.simple_unet import FourierEmbedding, TimeProjection, ConvLa
10
10
  import einops
11
11
  from flax.typing import Dtype, PrecisionLike
12
12
  from functools import partial
13
- from .common import hilbert_indices, inverse_permutation
13
+ from .hilbert import hilbert_indices, inverse_permutation, hilbert_patchify, hilbert_unpatchify
14
+
14
15
 
15
16
  def unpatchify(x, channels=3):
16
17
  patch_size = int((x.shape[2] // channels) ** 0.5)
17
18
  h = w = int(x.shape[1] ** .5)
18
- assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2], f"Invalid shape: {x.shape}, should be {h*w}, {patch_size**2*channels}"
19
- x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B (h p1) (w p2) C', h=h, p1=patch_size, p2=patch_size)
19
+ assert h * w == x.shape[1] and patch_size ** 2 * \
20
+ channels == x.shape[2], f"Invalid shape: {x.shape}, should be {h*w}, {patch_size**2*channels}"
21
+ x = einops.rearrange(
22
+ x, 'B (h w) (p1 p2 C) -> B (h p1) (w p2) C', h=h, p1=patch_size, p2=patch_size)
20
23
  return x
21
24
 
25
+
22
26
  class PatchEmbedding(nn.Module):
23
27
  patch_size: int
24
28
  embedding_dim: int
@@ -29,15 +33,16 @@ class PatchEmbedding(nn.Module):
29
33
  def __call__(self, x):
30
34
  batch, height, width, channels = x.shape
31
35
  assert height % self.patch_size == 0 and width % self.patch_size == 0, "Image dimensions must be divisible by patch size"
32
-
33
- x = nn.Conv(features=self.embedding_dim,
34
- kernel_size=(self.patch_size, self.patch_size),
36
+
37
+ x = nn.Conv(features=self.embedding_dim,
38
+ kernel_size=(self.patch_size, self.patch_size),
35
39
  strides=(self.patch_size, self.patch_size),
36
40
  dtype=self.dtype,
37
41
  precision=self.precision)(x)
38
42
  x = jnp.reshape(x, (batch, -1, self.embedding_dim))
39
43
  return x
40
44
 
45
+
41
46
  class PositionalEncoding(nn.Module):
42
47
  max_len: int
43
48
  embedding_dim: int
@@ -49,138 +54,294 @@ class PositionalEncoding(nn.Module):
49
54
  (1, self.max_len, self.embedding_dim))
50
55
  return x + pe[:, :x.shape[1], :]
51
56
 
57
+
52
58
  class UViT(nn.Module):
53
- output_channels:int=3
59
+ output_channels: int = 3
54
60
  patch_size: int = 16
55
- emb_features:int=768
56
- num_layers: int = 12
61
+ emb_features: int = 768
62
+ num_layers: int = 12 # Should be even for U-Net structure
57
63
  num_heads: int = 12
58
- dropout_rate: float = 0.1
59
- use_projection: bool = False
60
- use_flash_attention: bool = False
64
+ dropout_rate: float = 0.1 # Dropout is often 0 in diffusion models
65
+ use_projection: bool = False # In TransformerBlock MLP
66
+ use_flash_attention: bool = False # Passed to TransformerBlock
67
+ # Passed to TransformerBlock (likely False for UViT)
61
68
  use_self_and_cross: bool = False
62
- force_fp32_for_softmax: bool = True
63
- activation:Callable = jax.nn.swish
64
- norm_groups:int=8
65
- dtype: Optional[Dtype] = None
69
+ force_fp32_for_softmax: bool = True # Passed to TransformerBlock
70
+ # Used in final convs if add_residualblock_output
71
+ activation: Callable = jax.nn.swish
72
+ norm_groups: int = 8
73
+ dtype: Optional[Dtype] = None # e.g., jnp.float32 or jnp.bfloat16
66
74
  precision: PrecisionLike = None
67
75
  add_residualblock_output: bool = False
68
- norm_inputs: bool = False
69
- explicitly_add_residual: bool = True
70
- norm_epsilon: float = 1e-4 # Added epsilon parameter, increased default
71
- use_hilbert: bool = False # Toggle Hilbert patch reorder
76
+ norm_inputs: bool = False # Passed to TransformerBlock
77
+ explicitly_add_residual: bool = True # Passed to TransformerBlock
78
+ norm_epsilon: float = 1e-5 # Adjusted default
79
+ use_hilbert: bool = False # Toggle Hilbert patch reorder
80
+ use_remat: bool = False # Add flag to use remat
72
81
 
73
82
  def setup(self):
83
+ assert self.num_layers % 2 == 0, "num_layers must be even for U-Net structure"
84
+ half_layers = self.num_layers // 2
85
+
86
+ # --- Norm Layer ---
74
87
  if self.norm_groups > 0:
75
- self.norm = partial(nn.GroupNorm, self.norm_groups, epsilon=self.norm_epsilon)
88
+ # GroupNorm needs features arg, which varies. Define partial here, apply in __call__?
89
+ # Or maybe use LayerNorm/RMSNorm consistently? Let's use LayerNorm for simplicity here.
90
+ # If GroupNorm is essential, it needs careful handling with changing feature sizes.
91
+ # self.norm_factory = partial(nn.GroupNorm, self.norm_groups, epsilon=self.norm_epsilon, dtype=self.dtype)
92
+ print(f"Warning: norm_groups > 0 not fully supported with standard LayerNorm fallback in UViT setup. Using LayerNorm.")
93
+ self.norm_factory = partial(
94
+ nn.LayerNorm, epsilon=self.norm_epsilon, dtype=self.dtype)
95
+ else:
96
+ # Use LayerNorm or RMSNorm for sequence normalization
97
+ # self.norm_factory = partial(nn.RMSNorm, epsilon=self.norm_epsilon, dtype=self.dtype)
98
+ self.norm_factory = partial(
99
+ nn.LayerNorm, epsilon=self.norm_epsilon, dtype=self.dtype)
100
+
101
+ # --- Input Path ---
102
+ self.patch_embed = PatchEmbedding(
103
+ patch_size=self.patch_size,
104
+ embedding_dim=self.emb_features,
105
+ dtype=self.dtype,
106
+ precision=self.precision,
107
+ name="patch_embed"
108
+ )
109
+ if self.use_hilbert:
110
+ # Projection layer needed after raw Hilbert patches
111
+ self.hilbert_proj = nn.Dense(
112
+ features=self.emb_features,
113
+ dtype=self.dtype,
114
+ precision=self.precision,
115
+ name="hilbert_projection"
116
+ )
117
+
118
+ # Positional encoding (learned) - applied only to patch tokens
119
+ # Max length needs to accommodate max possible patches
120
+ # Example: 512x512 image, patch 16 -> (512/16)^2 = 1024 patches
121
+ # Estimate max patches, adjust if needed
122
+ max_patches = (512 // self.patch_size)**2
123
+ self.pos_encoding = self.param('pos_encoding',
124
+ # Standard init for ViT pos embeds
125
+ jax.nn.initializers.normal(stddev=0.02),
126
+ (1, max_patches, self.emb_features))
127
+
128
+ # --- Conditioning ---
129
+ self.time_embed = nn.Sequential([
130
+ FourierEmbedding(features=self.emb_features),
131
+ TimeProjection(features=self.emb_features)
132
+ ], name="time_embed")
133
+
134
+ # Text projection
135
+ self.text_proj = nn.DenseGeneral(
136
+ features=self.emb_features,
137
+ dtype=self.dtype,
138
+ precision=self.precision,
139
+ name="text_proj"
140
+ )
141
+
142
+ # --- Transformer Blocks ---
143
+ BlockClass = TransformerBlock
144
+
145
+ self.down_blocks = [
146
+ BlockClass(
147
+ heads=self.num_heads,
148
+ dim_head=self.emb_features // self.num_heads,
149
+ dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
150
+ use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross,
151
+ force_fp32_for_softmax=self.force_fp32_for_softmax,
152
+ only_pure_attention=False, norm_inputs=self.norm_inputs,
153
+ explicitly_add_residual=self.explicitly_add_residual,
154
+ norm_epsilon=self.norm_epsilon,
155
+ name=f"down_block_{i}"
156
+ ) for i in range(half_layers)
157
+ ]
158
+
159
+ self.mid_block = BlockClass(
160
+ heads=self.num_heads,
161
+ dim_head=self.emb_features // self.num_heads,
162
+ dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
163
+ use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross,
164
+ force_fp32_for_softmax=self.force_fp32_for_softmax,
165
+ only_pure_attention=False, norm_inputs=self.norm_inputs,
166
+ explicitly_add_residual=self.explicitly_add_residual,
167
+ norm_epsilon=self.norm_epsilon,
168
+ name="mid_block"
169
+ )
170
+
171
+ self.up_dense = [
172
+ nn.DenseGeneral( # Project concatenated skip + up_path features back to emb_features
173
+ features=self.emb_features,
174
+ dtype=self.dtype,
175
+ precision=self.precision,
176
+ name=f"up_dense_{i}"
177
+ ) for i in range(half_layers)
178
+ ]
179
+ self.up_blocks = [
180
+ BlockClass(
181
+ heads=self.num_heads,
182
+ dim_head=self.emb_features // self.num_heads,
183
+ dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
184
+ use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross,
185
+ force_fp32_for_softmax=self.force_fp32_for_softmax,
186
+ only_pure_attention=False, norm_inputs=self.norm_inputs,
187
+ explicitly_add_residual=self.explicitly_add_residual,
188
+ norm_epsilon=self.norm_epsilon,
189
+ name=f"up_block_{i}"
190
+ ) for i in range(half_layers)
191
+ ]
192
+
193
+ # --- Output Path ---
194
+ self.final_norm = self.norm_factory(name="final_norm") # Use factory
195
+
196
+ patch_dim = self.patch_size ** 2 * self.output_channels
197
+ self.final_proj = nn.Dense(
198
+ features=patch_dim,
199
+ dtype=self.dtype, # Keep model dtype for projection
200
+ precision=self.precision,
201
+ kernel_init=nn.initializers.zeros, # Zero init final layer
202
+ name="final_proj"
203
+ )
204
+
205
+ if self.add_residualblock_output:
206
+ # Define these layers only if needed
207
+ self.final_conv1 = ConvLayer(
208
+ "conv",
209
+ features=64, kernel_size=(3, 3), strides=(1, 1),
210
+ dtype=self.dtype, precision=self.precision, name="final_conv1"
211
+ )
212
+ self.final_norm_conv = self.norm_factory(
213
+ name="final_norm_conv") # Use factory
214
+ self.final_conv2 = ConvLayer(
215
+ "conv",
216
+ features=self.output_channels, kernel_size=(3, 3), strides=(1, 1),
217
+ dtype=jnp.float32, # Often good to have final conv output float32
218
+ precision=self.precision, name="final_conv2"
219
+ )
76
220
  else:
77
- self.norm = partial(nn.RMSNorm, epsilon=self.norm_epsilon)
78
-
221
+ # Final conv to map features to output channels directly after unpatchify
222
+ self.final_conv_direct = ConvLayer(
223
+ "conv",
224
+ # Use 1x1 conv
225
+ features=self.output_channels, kernel_size=(1, 1), strides=(1, 1),
226
+ dtype=jnp.float32, # Output float32
227
+ precision=self.precision, name="final_conv_direct"
228
+ )
229
+
79
230
  @nn.compact
80
231
  def __call__(self, x, temb, textcontext=None):
81
- # Time embedding
82
- temb = FourierEmbedding(features=self.emb_features)(temb)
83
- temb = TimeProjection(features=self.emb_features)(temb)
84
-
85
- original_img = x
232
+ original_img = x # Keep original for potential residual connection
86
233
  B, H, W, C = original_img.shape
87
234
  H_P = H // self.patch_size
88
235
  W_P = W // self.patch_size
236
+ num_patches = H_P * W_P
237
+ assert H % self.patch_size == 0 and W % self.patch_size == 0, "Image dimensions must be divisible by patch size"
89
238
 
90
- # Patch embedding
91
- x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features,
92
- dtype=self.dtype, precision=self.precision)(x)
93
- num_patches = x.shape[1]
94
-
95
- # Optional Hilbert reorder
239
+ # --- Patch Embedding ---
240
+ hilbert_inv_idx = None
96
241
  if self.use_hilbert:
242
+ # Use hilbert_patchify to get raw patches and inverse index
243
+ patches_raw, hilbert_inv_idx_calc = hilbert_patchify(
244
+ x, self.patch_size) # Shape [B, S, P*P*C]
245
+ # Project raw patches
246
+ # Shape [B, S, emb_features]
247
+ x_patches = self.hilbert_proj(patches_raw)
248
+ # Calculate inverse permutation (needs total_size)
97
249
  idx = hilbert_indices(H_P, W_P)
98
- inv_idx = inverse_permutation(idx)
99
- x = x[:, idx, :]
100
-
101
- context_emb = nn.DenseGeneral(features=self.emb_features,
102
- dtype=self.dtype, precision=self.precision)(textcontext)
103
- num_text_tokens = textcontext.shape[1]
104
-
105
- # Add time embedding
106
- temb = jnp.expand_dims(temb, axis=1)
107
- x = jnp.concatenate([x, temb, context_emb], axis=1)
108
-
109
- # Add positional encoding
110
- x = PositionalEncoding(max_len=x.shape[1], embedding_dim=self.emb_features)(x)
111
-
250
+ hilbert_inv_idx = inverse_permutation(
251
+ idx, total_size=num_patches) # Corrected call
252
+ # Apply Hilbert reordering *after* projection
253
+ x_patches = x_patches[:, idx, :]
254
+ else:
255
+ # Standard patch embedding
256
+ # Shape: [B, num_patches, emb_features]
257
+ x_patches = self.patch_embed(x)
258
+
259
+ # --- Positional Encoding ---
260
+ # Add positional encoding only to patch tokens
261
+ assert num_patches <= self.pos_encoding.shape[
262
+ 1], f"Number of patches {num_patches} exceeds max_len {self.pos_encoding.shape[1]} in positional encoding"
263
+ x_patches = x_patches + self.pos_encoding[:, :num_patches, :]
264
+
265
+ # --- Conditioning Tokens ---
266
+ # Time embedding: [B, D] -> [B, 1, D]
267
+ time_token = self.time_embed(temb.astype(
268
+ jnp.float32)) # Ensure input is float32
269
+ time_token = jnp.expand_dims(time_token.astype(
270
+ self.dtype), axis=1) # Cast back and add seq dim
271
+
272
+ # Text embedding: [B, S_text, D_in] -> [B, S_text, D]
273
+ if textcontext is not None:
274
+ text_tokens = self.text_proj(
275
+ textcontext.astype(self.dtype)) # Cast context
276
+ num_text_tokens = text_tokens.shape[1]
277
+ # Concatenate: [Patches+Pos, Time, Text]
278
+ x = jnp.concatenate([x_patches, time_token, text_tokens], axis=1)
279
+ else:
280
+ # Concatenate: [Patches+Pos, Time]
281
+ num_text_tokens = 0
282
+ x = jnp.concatenate([x_patches, time_token], axis=1)
283
+
284
+ # --- U-Net Transformer ---
112
285
  skips = []
113
- # In blocks
286
+ # Down blocks (Encoder)
114
287
  for i in range(self.num_layers // 2):
115
- x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
116
- dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
117
- use_flash_attention=self.use_flash_attention, use_self_and_cross=False, force_fp32_for_softmax=self.force_fp32_for_softmax,
118
- only_pure_attention=False,
119
- norm_inputs=self.norm_inputs,
120
- explicitly_add_residual=self.explicitly_add_residual,
121
- norm_epsilon=self.norm_epsilon, # Pass epsilon
122
- )(x)
123
- skips.append(x)
124
-
288
+ x = self.down_blocks[i](x) # Pass full sequence (patches+cond)
289
+ skips.append(x) # Store output for skip connection
290
+
125
291
  # Middle block
126
- x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
127
- dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
128
- use_flash_attention=self.use_flash_attention, use_self_and_cross=False, force_fp32_for_softmax=self.force_fp32_for_softmax,
129
- only_pure_attention=False,
130
- norm_inputs=self.norm_inputs,
131
- explicitly_add_residual=self.explicitly_add_residual,
132
- norm_epsilon=self.norm_epsilon, # Pass epsilon
133
- )(x)
134
-
135
- # Out blocks
292
+ x = self.mid_block(x)
293
+
294
+ # Up blocks (Decoder)
136
295
  for i in range(self.num_layers // 2):
137
- x = jnp.concatenate([x, skips.pop()], axis=-1)
138
- x = nn.DenseGeneral(features=self.emb_features,
139
- dtype=self.dtype, precision=self.precision)(x)
140
- x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
141
- dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
142
- use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
143
- only_pure_attention=False,
144
- norm_inputs=self.norm_inputs,
145
- explicitly_add_residual=self.explicitly_add_residual,
146
- norm_epsilon=self.norm_epsilon, # Pass epsilon
147
- )(x)
148
-
149
- x = self.norm()(x) # Uses norm_epsilon defined in setup
150
-
151
- patch_dim = self.patch_size ** 2 * self.output_channels
152
- x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision)(x)
153
- # If Hilbert, restore original patch order
154
- if self.use_hilbert:
155
- x = x[:, inv_idx, :]
296
+ skip_conn = skips.pop()
297
+ # Concatenate along feature dimension
298
+ x = jnp.concatenate([x, skip_conn], axis=-1)
299
+ # Project back to emb_features
300
+ x = self.up_dense[i](x)
301
+ # Apply transformer block
302
+ x = self.up_blocks[i](x)
303
+
304
+ # --- Output Processing ---
305
+ # Normalize before final projection
306
+ x = self.final_norm(x) # Apply norm factory instance
307
+
156
308
  # Extract only the image patch tokens (first num_patches tokens)
157
- x = x[:, :num_patches, :]
158
- x = unpatchify(x, channels=self.output_channels)
159
-
309
+ # Conditioning tokens (time, text) are discarded here
310
+ x_patches_out = x[:, :num_patches, :]
311
+
312
+ # Project to patch pixel dimensions
313
+ # Shape: [B, num_patches, patch_dim]
314
+ x_patches_out = self.final_proj(x_patches_out)
315
+
316
+ # --- Unpatchify ---
317
+ if self.use_hilbert:
318
+ # Restore Hilbert order to row-major order and then to image
319
+ assert hilbert_inv_idx is not None, "Hilbert inverse index missing"
320
+ x_image = hilbert_unpatchify(
321
+ x_patches_out, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
322
+ else:
323
+ # Standard unpatchify
324
+ # Shape: [B, H, W, C_out]
325
+ x_image = unpatchify(x_patches_out, channels=self.output_channels)
326
+
327
+ # --- Final Convolutions ---
160
328
  if self.add_residualblock_output:
161
- # Concatenate the original image
162
- x = jnp.concatenate([original_img, x], axis=-1)
163
-
164
- x = ConvLayer(
165
- "conv",
166
- features=64,
167
- kernel_size=(3, 3),
168
- strides=(1, 1),
169
- # activation=jax.nn.mish
170
- dtype=self.dtype,
171
- precision=self.precision
172
- )(x)
173
-
174
- x = self.norm()(x)
175
- x = self.activation(x)
176
-
177
- x = ConvLayer(
178
- "conv",
179
- features=self.output_channels,
180
- kernel_size=(3, 3),
181
- strides=(1, 1),
182
- # activation=jax.nn.mish
183
- dtype=self.dtype,
184
- precision=self.precision
185
- )(x)
186
- return x
329
+ # Concatenate the original image (ensure dtype matches)
330
+ x_image = jnp.concatenate(
331
+ [original_img.astype(self.dtype), x_image], axis=-1)
332
+
333
+ x_image = self.final_conv1(x_image)
334
+ # Apply norm factory instance
335
+ x_image = self.final_norm_conv(x_image)
336
+ x_image = self.activation(x_image)
337
+ x_image = self.final_conv2(x_image) # Outputs float32
338
+ else:
339
+ # Apply a simple 1x1 conv to map features if needed (unpatchify already gives C_out channels)
340
+ # Or just return x_image if channels match output_channels
341
+ # If unpatchify output channels == self.output_channels, this might be redundant
342
+ # Let's assume unpatchify gives correct channels, but ensure float32
343
+ # x_image = self.final_conv_direct(x_image) # Use 1x1 conv if needed
344
+ pass # Assuming unpatchify output is correct
345
+
346
+ # Ensure final output is float32
347
+ return x_image
@@ -129,6 +129,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
129
129
  frames_per_sample: int = None,
130
130
  wandb_config: Dict[str, Any] = None,
131
131
  eval_metrics: List[EvaluationMetric] = None,
132
+ best_tracker_metric: str = "train/best_loss",
132
133
  **kwargs
133
134
  ):
134
135
  """
@@ -196,6 +197,8 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
196
197
  **kwargs
197
198
  )
198
199
 
200
+ self.best_tracker_metric = best_tracker_metric
201
+
199
202
  # Store video-specific parameters
200
203
  self.frames_per_sample = frames_per_sample
201
204
 
@@ -203,6 +206,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
203
206
  self.conditional_inputs = input_config.conditions
204
207
  # Determine if we're working with video or images
205
208
  self.is_video = self._is_video_data()
209
+ self.best_val_metrics = {}
206
210
 
207
211
  def _is_video_data(self):
208
212
  sample_data_shape = self.input_config.sample_data_shape
@@ -423,7 +427,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
423
427
  process_index = jax.process_index()
424
428
  generate_samples = val_step_fn
425
429
 
426
- val_ds = iter(val_ds()) if val_ds else None
430
+ val_ds = iter(val_ds) if val_ds else None
427
431
  # Evaluation step
428
432
  try:
429
433
  metrics = {metric.name: [] for metric in self.eval_metrics} if self.eval_metrics else {}
@@ -465,11 +469,17 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
465
469
  else: # [B,H,W,C] - Image data
466
470
  self._log_image_samples(samples, current_step)
467
471
 
468
- if getattr(self, 'wandb', None) is not None and self.wandb:
469
- # metrics is a dict of metrics
470
- if metrics and type(metrics) == dict:
471
- # Flatten the metrics
472
- metrics = {k: np.mean(v) for k, v in metrics.items()}
472
+ # Flatten the metrics
473
+ if metrics:
474
+ metrics = {k: np.mean(v) for k, v in metrics.items()}
475
+ # Update the best validation metrics
476
+ for key, value in metrics.items():
477
+ if key not in self.best_val_metrics:
478
+ self.best_val_metrics[key] = value
479
+ else:
480
+ self.best_val_metrics[key] = min(self.best_val_metrics[key], value)
481
+ # Log the best validation metrics
482
+ if getattr(self, 'wandb', None) is not None and self.wandb:
473
483
  # Log the metrics
474
484
  for key, value in metrics.items():
475
485
  if isinstance(value, jnp.ndarray):
@@ -477,7 +487,10 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
477
487
  self.wandb.log({
478
488
  f"val/{key}": value,
479
489
  }, step=current_step)
480
-
490
+
491
+
492
+ # Close validation dataset iterator
493
+ del val_ds
481
494
  except StopIteration:
482
495
  print(f"Validation dataset exhausted for process index {process_index}")
483
496
  except Exception as e:
@@ -602,9 +615,13 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
602
615
  """
603
616
  if self.wandb is None:
604
617
  raise ValueError("Wandb is not initialized. Cannot get best runs.")
605
-
618
+ import wandb
606
619
  # Get the sweep runs
607
- runs = sorted(self.wandb.runs, key=lambda x: x.summary.get(metric, float('inf')))
620
+ runs = [i for i in wandb.Api().runs(path=f"{self.wandb.entity}/{self.wandb.project}", filters={"config.dataset.name": self.wandb.config['dataset']['name']})]
621
+ if not runs:
622
+ raise ValueError("No runs found in wandb.")
623
+ print(f"Getting best runs from wandb {self.wandb.id}...")
624
+ runs = sorted(runs, key=lambda x: x.summary.get(metric, float('inf')))
608
625
  best_runs = runs[:top_k]
609
626
  lower_bound = best_runs[-1].summary.get(metric, float('inf'))
610
627
  upper_bound = best_runs[0].summary.get(metric, float('inf'))
@@ -636,6 +653,8 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
636
653
  # Check if current run is one of the best
637
654
  if metric == "train/best_loss":
638
655
  current_run_metric = self.best_loss
656
+ elif metric in self.best_val_metrics:
657
+ current_run_metric = self.best_val_metrics[metric]
639
658
  else:
640
659
  current_run_metric = self.wandb.summary.get(metric, float('inf') if is_lower_better else float('-inf'))
641
660
 
@@ -653,7 +672,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
653
672
  if self.wandb is not None:
654
673
  checkpoint = get_latest_checkpoint(self.checkpoint_path())
655
674
  try:
656
- is_good, is_best = self.__compare_run_against_best__(top_k=5, metric="train/best_loss", from_sweeps=hasattr(self, "wandb_sweep"))
675
+ is_good, is_best = self.__compare_run_against_best__(top_k=5, metric=self.best_tracker_metric, from_sweeps=hasattr(self, "wandb_sweep"))
657
676
  if is_good:
658
677
  # Push to registry with appropriate aliases
659
678
  aliases = []