flaxdiff 0.2.9__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -121,8 +121,8 @@ def parse_config(config, overrides=None):
121
121
  'uvit': UViT,
122
122
  'diffusers_unet_simple': FlaxUNet2DConditionModel,
123
123
  'simple_dit': SimpleDiT,
124
- 'simple_uvit': SimpleUDiT,
125
124
  'simple_mmdit': SimpleMMDiT,
125
+ 'simple_udit': SimpleUDiT,
126
126
  }
127
127
 
128
128
  # Map all the leaves of the model config, converting strings to appropriate types
@@ -427,7 +427,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
427
427
  process_index = jax.process_index()
428
428
  generate_samples = val_step_fn
429
429
 
430
- val_ds = iter(val_ds) if val_ds else None
430
+ val_ds = iter(val_ds()) if val_ds else None
431
431
  print(f"Validation loop started for process index {process_index} with {global_device_count} devices.")
432
432
  # Evaluation step
433
433
  try:
@@ -475,10 +475,11 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
475
475
  metrics = {k: np.mean(v) for k, v in metrics.items()}
476
476
  # Update the best validation metrics
477
477
  for key, value in metrics.items():
478
- if key not in self.best_val_metrics:
479
- self.best_val_metrics[key] = value
478
+ final_key = f"val/{key}"
479
+ if final_key not in self.best_val_metrics:
480
+ self.best_val_metrics[final_key] = value
480
481
  else:
481
- self.best_val_metrics[key] = min(self.best_val_metrics[key], value)
482
+ self.best_val_metrics[final_key] = min(self.best_val_metrics[final_key], value)
482
483
  # Log the best validation metrics
483
484
  if getattr(self, 'wandb', None) is not None and self.wandb:
484
485
  # Log the metrics
@@ -488,6 +489,13 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
488
489
  self.wandb.log({
489
490
  f"val/{key}": value,
490
491
  }, step=current_step)
492
+ # Log the best validation metrics
493
+ for key, value in self.best_val_metrics.items():
494
+ if isinstance(value, jnp.ndarray):
495
+ value = np.array(value)
496
+ self.wandb.log({
497
+ f"best_{key}": value,
498
+ }, step=current_step)
491
499
  print(f"Validation metrics for process index {process_index}: {metrics}")
492
500
 
493
501
  # Close validation dataset iterator
@@ -622,10 +630,10 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
622
630
  if not runs:
623
631
  raise ValueError("No runs found in wandb.")
624
632
  print(f"Getting best runs from wandb {self.wandb.id}...")
625
- runs = sorted(runs, key=lambda x: x.summary.get(metric, float('inf')))
633
+ runs = sorted(runs, key=lambda x: x.summary.get(f"best_{metric}", float('inf')))
626
634
  best_runs = runs[:top_k]
627
- lower_bound = best_runs[-1].summary.get(metric, float('inf'))
628
- upper_bound = best_runs[0].summary.get(metric, float('inf'))
635
+ lower_bound = best_runs[-1].summary.get(f"best_{metric}", float('inf'))
636
+ upper_bound = best_runs[0].summary.get(f"best_{metric}", float('inf'))
629
637
  print(f"Best runs from wandb {self.wandb.id}:")
630
638
  for run in best_runs:
631
639
  print(f"\t\tRun ID: {run.id}, Metric: {run.summary.get(metric, float('inf'))}")
@@ -649,19 +657,21 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
649
657
  best_runs, bounds = self.__get_best_general_runs__(metric=metric, top_k=top_k)
650
658
 
651
659
  # Determine if lower or higher values are better (for loss, lower is better)
652
- is_lower_better = "loss" in metric.lower()
660
+ is_lower_better = True
653
661
 
654
662
  # Check if current run is one of the best
655
663
  if metric == "train/best_loss":
656
664
  current_run_metric = self.best_loss
657
665
  elif metric in self.best_val_metrics:
666
+ print(f"Fetching best validation metric {metric} from local")
658
667
  current_run_metric = self.best_val_metrics[metric]
659
668
  else:
660
669
  current_run_metric = self.wandb.summary.get(metric, float('inf') if is_lower_better else float('-inf'))
661
-
670
+
671
+ print(f"Current run {self.wandb.id} metric: {current_run_metric}, Best bounds: {bounds}")
662
672
  # Check based on bounds
663
673
  if (is_lower_better and current_run_metric < bounds[1]) or (not is_lower_better and current_run_metric > bounds[0]):
664
- print(f"Current run {self.wandb.id} meets performance criteria.")
674
+ print(f"Current run {self.wandb.id} meets performance criteria. Current metric: {current_run_metric}, Best bounds: {bounds}")
665
675
  is_best = (is_lower_better and current_run_metric < bounds[0]) or (not is_lower_better and current_run_metric > bounds[1])
666
676
  return True, is_best
667
677
 
@@ -600,7 +600,7 @@ class SimpleTrainer:
600
600
 
601
601
  def fit(self, data, train_steps_per_epoch, epochs, train_step_args={}, val_steps_per_epoch=5, validation_step_args={}):
602
602
  train_ds = iter(data['train']())
603
- val_ds = data.get('val', data.get('test', None))()
603
+ val_ds = data.get('val', data.get('test', None))
604
604
  train_step = self._define_train_step(**train_step_args)
605
605
  val_step = self._define_validation_step(**validation_step_args)
606
606
  train_state = self.state
@@ -642,6 +642,19 @@ class SimpleTrainer:
642
642
  self.rngstate = rng_state
643
643
  total_time = end_time - start_time
644
644
  avg_time_per_step = total_time / train_steps_per_epoch
645
+
646
+ if val_steps_per_epoch > 0:
647
+ print(f"Validation started for process index {process_index}")
648
+ # Validation step
649
+ self.validation_loop(
650
+ train_state,
651
+ val_step,
652
+ val_ds,
653
+ val_steps_per_epoch,
654
+ current_step,
655
+ )
656
+ print(colored(f"Validation done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
657
+
645
658
  avg_loss = epoch_loss / train_steps_per_epoch
646
659
  if avg_loss < self.best_loss:
647
660
  self.best_loss = avg_loss
@@ -659,17 +672,6 @@ class SimpleTrainer:
659
672
  }, step=current_step)
660
673
  print(colored(f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}", 'green'))
661
674
 
662
- if val_steps_per_epoch > 0:
663
- print(f"Validation started for process index {process_index}")
664
- # Validation step
665
- self.validation_loop(
666
- train_state,
667
- val_step,
668
- val_ds,
669
- val_steps_per_epoch,
670
- current_step,
671
- )
672
- print(colored(f"Validation done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
673
675
 
674
676
  self.save(epochs)#
675
677
  return self.state
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.2.9
3
+ Version: 0.2.10
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
@@ -15,7 +15,7 @@ flaxdiff/data/sources/videos.py,sha256=NkxwEruNpAwDCM53q4WurQ802gSjQMOqjNLxYOqjo
15
15
  flaxdiff/data/sources/voxceleb2.py,sha256=BoKfat_hsw6ObDyyaiQmPbBzuFiqgCGlgAZmf-t5Iz8,18621
16
16
  flaxdiff/inference/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
17
  flaxdiff/inference/pipeline.py,sha256=8S30FAlXEjvrDd87H-qdD6biySQZ3cJUflU8gdmPxig,9223
18
- flaxdiff/inference/utils.py,sha256=JEBZYSgj-0DLJTV-TNmIAllAqqVJMn0KfryHwFO-MFs,12606
18
+ flaxdiff/inference/utils.py,sha256=Dh0KawgvQrZxyqN_9wbsb7gUyvPRendwb-YtAU6zIBE,12606
19
19
  flaxdiff/inputs/__init__.py,sha256=ybPjQsFAf5sqRVZG1sRiOl99EnwpI-NQ8HE3y7UbXmU,7197
20
20
  flaxdiff/inputs/encoders.py,sha256=pjfbx4Rk7bLoE80MOfThZDm6YtsDncRekmn0Bmg_CwI,2963
21
21
  flaxdiff/metrics/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -63,9 +63,9 @@ flaxdiff/schedulers/sqrt.py,sha256=mCd_szmOqF6vqQKiAiEOqV_3eBIPGYrW3VxK0o4rBuo,4
63
63
  flaxdiff/trainer/__init__.py,sha256=xSoierfi26gxfgxlNnwvyyPmuPAJ--5i3mEHxt3S-AE,215
64
64
  flaxdiff/trainer/autoencoder_trainer.py,sha256=2FP2P-k9c0n_k3eT0trkq73dQrHRdBj9ObK1idcyhSw,6996
65
65
  flaxdiff/trainer/diffusion_trainer.py,sha256=reQEVWKTqKAeyCMQ-curPOfSRmBKxKooK8EVtUuorcM,14599
66
- flaxdiff/trainer/general_diffusion_trainer.py,sha256=OtE2spZIBFPpY6q-ijYol5Y-CaP2UHJYIDX3PFBiPtg,29492
67
- flaxdiff/trainer/simple_trainer.py,sha256=Hdltuo3lgF61N04Lxc7L3z6NLveW4_h1ff7_5mu3Wbg,28730
68
- flaxdiff-0.2.9.dist-info/METADATA,sha256=a8btxHRkAZVieuZfTyXgPkJbEG9fZRknEhq2Ti3_7m4,24057
69
- flaxdiff-0.2.9.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
70
- flaxdiff-0.2.9.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
71
- flaxdiff-0.2.9.dist-info/RECORD,,
66
+ flaxdiff/trainer/general_diffusion_trainer.py,sha256=gMo0OOz8EFKGfiqZnDwhVSxtk_IUMGUvyt5TTr_Hk8g,30168
67
+ flaxdiff/trainer/simple_trainer.py,sha256=nXYy9tadteG8N0RovpevPPEs6oeFvbr2gVq7Zot9l78,28754
68
+ flaxdiff-0.2.10.dist-info/METADATA,sha256=xsqksvLSps2a9nNdvZkguWvsC07kX8A3Z26DPTq-tGI,24058
69
+ flaxdiff-0.2.10.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
70
+ flaxdiff-0.2.10.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
71
+ flaxdiff-0.2.10.dist-info/RECORD,,