flaxdiff 0.1.38__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. flaxdiff/data/__init__.py +5 -1
  2. flaxdiff/data/benchmark_decord.py +443 -0
  3. flaxdiff/data/dataloaders.py +608 -0
  4. flaxdiff/data/dataset_map.py +61 -6
  5. flaxdiff/data/online_loader.py +779 -150
  6. flaxdiff/data/sources/audio_utils.py +142 -0
  7. flaxdiff/data/sources/av_example.py +125 -0
  8. flaxdiff/data/sources/av_utils.py +590 -0
  9. flaxdiff/data/sources/base.py +129 -0
  10. flaxdiff/data/sources/images.py +309 -0
  11. flaxdiff/data/sources/utils.py +158 -0
  12. flaxdiff/data/sources/videos.py +250 -0
  13. flaxdiff/data/sources/voxceleb2.py +412 -0
  14. flaxdiff/inference/__init__.py +0 -0
  15. flaxdiff/inference/pipeline.py +260 -0
  16. flaxdiff/inference/utils.py +320 -0
  17. flaxdiff/inputs/__init__.py +173 -0
  18. flaxdiff/inputs/encoders.py +98 -0
  19. flaxdiff/models/__init__.py +2 -1
  20. flaxdiff/models/attention.py +22 -16
  21. flaxdiff/models/autoencoder/autoencoder.py +141 -9
  22. flaxdiff/models/autoencoder/diffusers.py +88 -25
  23. flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
  24. flaxdiff/models/common.py +8 -18
  25. flaxdiff/models/simple_unet.py +6 -17
  26. flaxdiff/models/simple_vit.py +9 -13
  27. flaxdiff/models/unet_3d.py +446 -0
  28. flaxdiff/models/unet_3d_blocks.py +505 -0
  29. flaxdiff/samplers/common.py +358 -96
  30. flaxdiff/samplers/ddim.py +44 -5
  31. flaxdiff/schedulers/karras.py +20 -12
  32. flaxdiff/trainer/__init__.py +2 -1
  33. flaxdiff/trainer/autoencoder_trainer.py +1 -2
  34. flaxdiff/trainer/diffusion_trainer.py +35 -29
  35. flaxdiff/trainer/general_diffusion_trainer.py +583 -0
  36. flaxdiff/trainer/simple_trainer.py +51 -16
  37. flaxdiff/utils.py +128 -57
  38. {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
  39. flaxdiff-0.2.0.dist-info/RECORD +64 -0
  40. {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
  41. flaxdiff/data/datasets.py +0 -169
  42. flaxdiff/data/sources/gcs.py +0 -81
  43. flaxdiff/data/sources/tfds.py +0 -79
  44. flaxdiff/trainer/video_diffusion_trainer.py +0 -62
  45. flaxdiff-0.1.38.dist-info/RECORD +0 -50
  46. {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
flaxdiff/trainer/general_diffusion_trainer.py
@@ -0,0 +1,583 @@
+ import json
+ import flax
+ from flax import linen as nn
+ import jax
+ from typing import Callable, List, Dict, Tuple, Union, Any, Sequence, Type, Optional
+ from dataclasses import field, dataclass
+ import jax.numpy as jnp
+ import optax
+ import functools
+ from jax.sharding import Mesh, PartitionSpec as P
+ from jax.experimental.shard_map import shard_map
+
+ from ..schedulers import NoiseScheduler, get_coeff_shapes_tuple
+ from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
+ from ..samplers.common import DiffusionSampler
+ from ..samplers.ddim import DDIMSampler
+
+ from flaxdiff.utils import RandomMarkovState, serialize_model, get_latest_checkpoint
+ from flaxdiff.inputs import ConditioningEncoder, ConditionalInputConfig, DiffusionInputConfig
+
+ from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
+
+ from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
+ from flax.training import dynamic_scale as dynamic_scale_lib
+
+ # Reuse the TrainState from the DiffusionTrainer
+ from flaxdiff.trainer.diffusion_trainer import TrainState, DiffusionTrainer
+ import shutil
+
+ def generate_modelname(
+     dataset_name: str,
+     noise_schedule_name: str,
+     architecture_name: str,
+     model: nn.Module,
+     input_config: DiffusionInputConfig,
+     autoencoder: AutoEncoder = None,
+     frames_per_sample: int = None,
+ ) -> str:
+     """
+     Generate a model name based on the configuration.
+
+     Args:
+         config: Configuration dictionary.
+
+     Returns:
+         A string representing the model name.
+     """
+     import hashlib
+     import json
+
+     # Extract key components for the name
+
+     model_name = f"diffusion-{dataset_name}-res{input_config.sample_data_shape[-2]}"
+
+     # model_name = f"diffusion-{dataset_name}-res{input_config.sample_data_shape[-2]}-{noise_schedule_name}-{architecture_name}"
+
+     # if autoencoder is not None:
+     #     model_name += f"-vae"
+
+     # if frames_per_sample is not None:
+     #     model_name += f"-frames_{frames_per_sample}"
+
+     # model_name += f"-{'.'.join([cond.encoder.key for cond in input_config.conditions])}"
+
+     # # Create a sorted representation of model config for consistent hashing
+     # def sort_dict_recursively(d):
+     #     if isinstance(d, dict):
+     #         return {k: sort_dict_recursively(d[k]) for k in sorted(d.keys())}
+     #     elif isinstance(d, list):
+     #         return [sort_dict_recursively(v) for v in d]
+     #     else:
+     #         return d
+
+     # # Extract model config and sort it
+     # model_config = serialize_model(model)
+     # sorted_model_config = sort_dict_recursively(model_config)
+
+     # # Convert to JSON string with sorted keys for consistent hash
+     # try:
+     #     config_json = json.dumps(sorted_model_config)
+     # except TypeError:
+     #     # Handle non-serializable objects
+     #     def make_serializable(obj):
+     #         if isinstance(obj, dict):
+     #             return {k: make_serializable(v) for k, v in obj.items()}
+     #         elif isinstance(obj, list):
+     #             return [make_serializable(v) for v in obj]
+     #         else:
+     #             try:
+     #                 # Test if object is JSON serializable
+     #                 json.dumps(obj)
+     #                 return obj
+     #             except TypeError:
+     #                 return str(obj)
+
+     #     serializable_config = make_serializable(sorted_model_config)
+     #     config_json = json.dumps(serializable_config)
+
+     # # Generate a hash of the configuration
+     # config_hash = hashlib.md5(config_json.encode('utf-8')).hexdigest()[:8]
+
+     # # Construct the model name
+     # model_name = f"{model_name}-{config_hash}"
+     return model_name
+
+ class GeneralDiffusionTrainer(DiffusionTrainer):
+     """
+     General trainer for diffusion models supporting both images and videos.
+
+     Extends DiffusionTrainer to support:
+     1. Both image data (4D tensors: B,H,W,C) and video data (5D tensors: B,T,H,W,C)
+     2. Multiple conditioning inputs
+     3. Various model architectures
+     """
+
+     def __init__(self,
+                  model: nn.Module,
+                  optimizer: optax.GradientTransformation,
+                  noise_schedule: NoiseScheduler,
+                  input_config: DiffusionInputConfig,
+                  rngs: jax.random.PRNGKey,
+                  unconditional_prob: float = 0.12,
+                  name: str = "GeneralDiffusion",
+                  model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
+                  autoencoder: AutoEncoder = None,
+                  native_resolution: int = None,
+                  frames_per_sample: int = None,
+                  wandb_config: Dict[str, Any] = None,
+                  **kwargs
+                  ):
+         """
+         Initialize the general diffusion trainer.
+
+         Args:
+             model: Neural network model
+             optimizer: Optimization algorithm
+             noise_schedule: Noise scheduler for diffusion process
+             input_config: Configuration for input data, including keys, shapes and conditioning inputs
+             rngs: Random number generator keys
+             unconditional_prob: Probability of training with unconditional samples
+             name: Name of this trainer
+             model_output_transform: Transform for model predictions
+             autoencoder: Optional autoencoder for latent diffusion
+             native_resolution: Native resolution of the data
+             frames_per_sample: Number of frames per video sample (for video only)
+             **kwargs: Additional arguments for parent class
+         """
+         # Initialize with parent DiffusionTrainer but without encoder parameter
+         input_shapes = input_config.get_input_shapes(
+             autoencoder=autoencoder,
+         )
+         self.input_config = input_config
+
+         if wandb_config is not None:
+             # If input_config is not in wandb_config, add it
+             if 'input_config' not in wandb_config['config']:
+                 wandb_config['config']['input_config'] = input_config.serialize()
+             # If model is not in wandb_config, add it
+             if 'model' not in wandb_config['config']:
+                 wandb_config['config']['model'] = serialize_model(model)
+             if 'autoencoder' not in wandb_config['config'] and autoencoder is not None:
+                 wandb_config['config']['autoencoder'] = autoencoder.name
+                 wandb_config['config']['autoencoder_opts'] = json.dumps(autoencoder.serialize())
+
+             # Generate a model name based on the configuration
+             modelname = generate_modelname(
+                 dataset_name=wandb_config['config']['arguments']['dataset'],
+                 noise_schedule_name=wandb_config['config']['arguments']['noise_schedule'],
+                 architecture_name=wandb_config['config']['arguments']['architecture'],
+                 model=model,
+                 input_config=input_config,
+                 autoencoder=autoencoder,
+                 frames_per_sample=frames_per_sample,
+             )
+             print("Model name:", modelname)
+             self.modelname = modelname
+             wandb_config['config']['modelname'] = modelname
+
+         super().__init__(
+             model=model,
+             input_shapes=input_shapes,
+             optimizer=optimizer,
+             noise_schedule=noise_schedule,
+             unconditional_prob=unconditional_prob,
+             autoencoder=autoencoder,
+             model_output_transform=model_output_transform,
+             rngs=rngs,
+             name=name,
+             native_resolution=native_resolution,
+             encoder=None, # Don't use the default encoder from the parent class
+             wandb_config=wandb_config,
+             **kwargs
+         )
+
+         # Store video-specific parameters
+         self.frames_per_sample = frames_per_sample
+
+         # List of conditional inputs
+         self.conditional_inputs = input_config.conditions
+         # Determine if we're working with video or images
+         self.is_video = self._is_video_data()
+
+     def _is_video_data(self):
+         sample_data_shape = self.input_config.sample_data_shape
+         return len(sample_data_shape) == 5
+
+     def _define_train_step(self, batch_size):
+         """
+         Define the training step function for both image and video diffusion.
+         Optimized for efficient sharding and JIT compilation.
+         """
+         # Access class variables once for JIT optimization
+         noise_schedule = self.noise_schedule
+         model = self.model
+         model_output_transform = self.model_output_transform
+         loss_fn = self.loss_fn
+         distributed_training = self.distributed_training
+         autoencoder = self.autoencoder
+         unconditional_prob = self.unconditional_prob
+
+         input_config = self.input_config
+         sample_data_key = input_config.sample_data_key
+
+         # JIT-optimized function for processing conditional inputs
+         # @functools.partial(jax.jit, static_argnums=(2,))
+         def process_conditioning(batch, uncond_mask):
+             return input_config.process_conditioning(
+                 batch,
+                 uncond_mask=uncond_mask,
+             )
+
+         # Main training step function - optimized for JIT compilation and sharding
+         def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, local_device_index):
+             """Training step optimized for distributed execution."""
+             # Random key handling
+             rng_state, key_fold = rng_state.get_random_key()
+             folded_key = jax.random.fold_in(key_fold, local_device_index.reshape())
+             local_rng_state = RandomMarkovState(folded_key)
+
+             # Extract and normalize data (works for both images and videos)
+             data = batch[sample_data_key]
+             local_batch_size = data.shape[0]
+             data = (jnp.asarray(data, dtype=jnp.float32) - 127.5) / 127.5
+
+             # Autoencoder step (handles both image and video data)
+             if autoencoder is not None:
+                 local_rng_state, enc_key = local_rng_state.get_random_key()
+                 data = autoencoder.encode(data, enc_key)
+
+             # Determine number of unconditional samples per mini batch randomly
+             local_rng_state, uncond_key = local_rng_state.get_random_key()
+             # Determine unconditional samples
+             uncond_mask = jax.random.bernoulli(
+                 uncond_key,
+                 shape=(local_batch_size,),
+                 p=unconditional_prob
+             )
+
+             # Process conditioning
+             all_conditional_inputs = process_conditioning(batch, uncond_mask)
+
+             # Generate diffusion timesteps
+             noise_level, local_rng_state = noise_schedule.generate_timesteps(local_batch_size, local_rng_state)
+
+             # Generate noise
+             local_rng_state, noise_key = local_rng_state.get_random_key()
+             noise = jax.random.normal(noise_key, shape=data.shape, dtype=jnp.float32)
+
+             # Forward diffusion process
+             rates = noise_schedule.get_rates(noise_level, get_coeff_shapes_tuple(data))
+             noisy_data, c_in, expected_output = model_output_transform.forward_diffusion(data, noise, rates)
+
+             # Loss function
+             def model_loss(params):
+                 # Apply model
+                 inputs = noise_schedule.transform_inputs(noisy_data * c_in, noise_level)
+                 preds = model.apply(params, *inputs, *all_conditional_inputs)
+
+                 # Transform predictions and calculate loss
+                 preds = model_output_transform.pred_transform(noisy_data, preds, rates)
+                 sample_losses = loss_fn(preds, expected_output)
+
+                 # Apply loss weighting
+                 weights = noise_schedule.get_weights(noise_level, get_coeff_shapes_tuple(sample_losses))
+                 weighted_loss = sample_losses * weights
+
+                 return jnp.mean(weighted_loss)
+
+             # Compute gradients and apply updates
+             if train_state.dynamic_scale is not None:
+                 # Mixed precision training with dynamic scale
+                 grad_fn = train_state.dynamic_scale.value_and_grad(model_loss, axis_name="data")
+                 dynamic_scale, is_finite, loss, grads = grad_fn(train_state.params)
+
+                 train_state = train_state.replace(dynamic_scale=dynamic_scale)
+                 new_state = train_state.apply_gradients(grads=grads)
+
+                 # Handle NaN/Inf gradients
+                 select_fn = functools.partial(jnp.where, is_finite)
+                 new_state = new_state.replace(
+                     opt_state=jax.tree_map(select_fn, new_state.opt_state, train_state.opt_state),
+                     params=jax.tree_map(select_fn, new_state.params, train_state.params)
+                 )
+             else:
+                 # Standard gradient computation
+                 grad_fn = jax.value_and_grad(model_loss)
+                 loss, grads = grad_fn(train_state.params)
+
+                 if distributed_training:
+                     grads = jax.lax.pmean(grads, axis_name="data")
+
+                 new_state = train_state.apply_gradients(grads=grads)
+
+             # Apply EMA update
+             new_state = new_state.apply_ema(self.ema_decay)
+
+             # Average loss across devices if distributed
+             if distributed_training:
+                 loss = jax.lax.pmean(loss, axis_name="data")
+
+             return new_state, loss, rng_state
+
+         # Apply sharding for distributed training
+         if distributed_training:
+             train_step = shard_map(
+                 train_step,
+                 mesh=self.mesh,
+                 in_specs=(P(), P(), P('data'), P('data')),
+                 out_specs=(P(), P(), P()),
+             )
+
+         # Apply JIT compilation
+         train_step = jax.jit(train_step, donate_argnums=(2))
+         return train_step
+
+     def _define_validation_step(self, sampler_class: Type[DiffusionSampler]=DDIMSampler, sampling_noise_schedule: NoiseScheduler=None):
+         """
+         Define the validation step for both image and video diffusion models.
+         """
+         # Setup for validation
+         model = self.model
+         autoencoder = self.autoencoder
+         input_config = self.input_config
+         conditional_inputs = self.conditional_inputs
+         is_video = self.is_video
+
+         # Get necessary parameters
+         image_size = self._get_image_size()
+
+         # Get sequence length only for video data
+         sequence_length = self._get_sequence_length() if is_video else None
+
+         # Initialize the sampler
+         sampler = sampler_class(
+             model=model,
+             noise_schedule=self.noise_schedule if sampling_noise_schedule is None else sampling_noise_schedule,
+             model_output_transform=self.model_output_transform,
+             input_config=input_config,
+             autoencoder=autoencoder,
+             guidance_scale=3.0,
+         )
+
+         def generate_samples(
+             val_state: TrainState,
+             batch,
+             sampler: DiffusionSampler,
+             diffusion_steps: int,
+         ):
+             # Process all conditional inputs
+             model_conditioning_inputs = [cond_input(batch) for cond_input in conditional_inputs]
+
+             # Determine batch size
+             batch_size = len(model_conditioning_inputs[0]) if model_conditioning_inputs else 4
+
+             # Generate samples - works for both images and videos
+             return sampler.generate_samples(
+                 params=val_state.ema_params,
+                 resolution=image_size,
+                 num_samples=batch_size,
+                 sequence_length=sequence_length, # Will be None for images
+                 diffusion_steps=diffusion_steps,
+                 start_step=1000,
+                 end_step=0,
+                 priors=None,
+                 model_conditioning_inputs=tuple(model_conditioning_inputs),
+             )
+
+         return sampler, generate_samples
+
+     def _get_image_size(self):
+         """Helper to determine image size from available information."""
+         if self.native_resolution is not None:
+             return self.native_resolution
+
+         sample_data_shape = self.input_config.sample_data_shape
+         return sample_data_shape[-2] # Assuming [..., H, W, C] format
+
+     def _get_sequence_length(self):
+         """Helper to determine sequence length for video generation."""
+         if not self.is_video:
+             return None
+
+         sample_data_shape = self.input_config.sample_data_shape
+         return sample_data_shape[1] # Assuming [B,T,H,W,C] format
+
+     def validation_loop(
+         self,
+         val_state: SimpleTrainState,
+         val_step_fn: Callable,
+         val_ds,
+         val_steps_per_epoch,
+         current_step,
+         diffusion_steps=200,
+     ):
+         """
+         Run validation and log samples for both image and video diffusion.
+         """
+         sampler, generate_samples = val_step_fn
+         val_ds = iter(val_ds()) if val_ds else None
+
+         try:
+             # Generate samples
+             samples = generate_samples(
+                 val_state,
+                 next(val_ds),
+                 sampler,
+                 diffusion_steps,
+             )
+
+             # Log samples to wandb
+             if getattr(self, 'wandb', None) is not None and self.wandb:
+                 import numpy as np
+
+                 # Process samples differently based on dimensionality
+                 if len(samples.shape) == 5: # [B,T,H,W,C] - Video data
+                     self._log_video_samples(samples, current_step)
+                 else: # [B,H,W,C] - Image data
+                     self._log_image_samples(samples, current_step)
+
+         except Exception as e:
+             print("Error in validation loop:", e)
+             import traceback
+             traceback.print_exc()
+
+     def _log_video_samples(self, samples, current_step):
+         """Helper to log video samples to wandb."""
+         import numpy as np
+         from wandb import Video as wandbVideo
+
+         for i in range(samples.shape[0]):
+             # Convert to numpy, denormalize and clip
+             sample = np.array(samples[i])
+             sample = (sample + 1) * 127.5
+             sample = np.clip(sample, 0, 255).astype(np.uint8)
+
+             # Log as video
+             self.wandb.log({
+                 f"video_sample_{i}": wandbVideo(
+                     sample,
+                     fps=10,
+                     caption=f"Video Sample {i} at step {current_step}"
+                 )
+             }, step=current_step)
+
+     def _log_image_samples(self, samples, current_step):
+         """Helper to log image samples to wandb."""
+         import numpy as np
+         from wandb import Image as wandbImage
+
+         for i in range(samples.shape[0]):
+             # Convert to numpy, denormalize and clip
+             sample = np.array(samples[i])
+             sample = (sample + 1) * 127.5
+             sample = np.clip(sample, 0, 255).astype(np.uint8)
+
+             # Log as image
+             self.wandb.log({
+                 f"sample_{i}": wandbImage(
+                     sample,
+                     caption=f"Sample {i} at step {current_step}"
+                 )
+             }, step=current_step)
+
+     def push_to_registry(
+         self,
+         registry_name: str = 'wandb-registry-model',
+     ):
+         """
+         Push the model to wandb registry.
+         Args:
+             registry_name: Name of the model registry.
+         """
+         if self.wandb is None:
+             raise ValueError("Wandb is not initialized. Cannot push to registry.")
+
+         modelname = self.modelname
+         if hasattr(self, "wandb_sweep"):
+             modelname = f"{modelname}-sweep-{self.wandb_sweep.id}"
+
+         latest_checkpoint_path = get_latest_checkpoint(self.checkpoint_path())
+         logged_artifact = self.wandb.log_artifact(
+             artifact_or_path=latest_checkpoint_path,
+             name=modelname,
+             type="model",
+         )
+
+         target_path = f"{registry_name}/{modelname}"
+
+         self.wandb.link_artifact(
+             artifact=logged_artifact,
+             target_path=target_path,
+         )
+         print(f"Model pushed to registry at {target_path}")
+         return logged_artifact
+
+     def __get_best_sweep_runs__(
+         self,
+         metric: str = "train/best_loss",
+         top_k: int = 5,
+     ):
+         """
+         Get the best runs from a wandb sweep.
+         Args:
+             metric: Metric to sort by.
+             top_k: Number of top runs to return.
+         """
+         if self.wandb is None:
+             raise ValueError("Wandb is not initialized. Cannot get best runs.")
+
+         if not hasattr(self, "wandb_sweep"):
+             raise ValueError("Wandb sweep is not initialized. Cannot get best runs.")
+
+         # Get the sweep runs
+         runs = sorted(self.wandb_sweep.runs, key=lambda x: x.summary.get(metric, float('inf')))
+         best_runs = runs[:top_k]
+         lower_bound = best_runs[-1].summary.get(metric, float('inf'))
+         upper_bound = best_runs[0].summary.get(metric, float('inf'))
+         print(f"Best runs from sweep {self.wandb_sweep.id}:")
+         for run in best_runs:
+             print(f"\t\tRun ID: {run.id}, Metric: {run.summary.get(metric, float('inf'))}")
+         return best_runs, (min(lower_bound, upper_bound), max(lower_bound, upper_bound))
+
+     def __compare_run_against_best__(self, top_k=2, metric="train/best_loss"):
+         # Get best runs
+         best_runs, bounds = self.__get_best_sweep_runs__(metric=metric, top_k=top_k)
+
+         # Determine if lower or higher values are better (for loss, lower is better)
+         is_lower_better = "loss" in metric.lower()
+
+         # Check if current run is one of the best
+         current_run_metric = self.wandb.summary.get(metric, float('inf') if is_lower_better else float('-inf'))
+
+         # Direct check if current run is in best runs
+         for run in best_runs:
+             if run.id == self.wandb.id:
+                 print(f"Current run {self.wandb.id} is one of the best runs.")
+                 return True
+
+         # Backup check based on metric value
+         if (is_lower_better and current_run_metric < bounds[1]) or (not is_lower_better and current_run_metric > bounds[0]):
+             print(f"Current run {self.wandb.id} meets performance criteria.")
+             return True
+
+         return False
+
+     def save(self, epoch=0, step=0, state=None, rngstate=None):
+         super().save(epoch=epoch, step=step, state=state, rngstate=rngstate)
+
+         if self.wandb is not None and hasattr(self, "wandb_sweep"):
+             checkpoint = get_latest_checkpoint(self.checkpoint_path())
+             try:
+                 if self.__compare_run_against_best__(top_k=5, metric="train/best_loss"):
+                     self.push_to_registry()
+                     print("Model pushed to registry successfully")
+                 else:
+                     print("Current run is not one of the best runs. Not saving model.")
+
+                 # Only delete after successful registry push
+                 shutil.rmtree(checkpoint, ignore_errors=True)
+                 print(f"Checkpoint deleted at {checkpoint}")
+             except Exception as e:
+                 print(f"Error during registry operations: {e}")
+                 print(f"Checkpoint preserved at {checkpoint}")