flaxdiff 0.2.9__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/inference/utils.py +1 -1
- flaxdiff/trainer/general_diffusion_trainer.py +20 -10
- flaxdiff/trainer/simple_trainer.py +14 -12
- {flaxdiff-0.2.9.dist-info → flaxdiff-0.2.10.dist-info}/METADATA +1 -1
- {flaxdiff-0.2.9.dist-info → flaxdiff-0.2.10.dist-info}/RECORD +7 -7
- {flaxdiff-0.2.9.dist-info → flaxdiff-0.2.10.dist-info}/WHEEL +0 -0
- {flaxdiff-0.2.9.dist-info → flaxdiff-0.2.10.dist-info}/top_level.txt +0 -0
flaxdiff/inference/utils.py
CHANGED
@@ -121,8 +121,8 @@ def parse_config(config, overrides=None):
         'uvit': UViT,
         'diffusers_unet_simple': FlaxUNet2DConditionModel,
         'simple_dit': SimpleDiT,
-        'simple_uvit': SimpleUDiT,
         'simple_mmdit': SimpleMMDiT,
+        'simple_udit': SimpleUDiT,
     }

     # Map all the leaves of the model config, converting strings to appropriate types
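The change corrects the registry key for SimpleUDiT from 'simple_uvit' to 'simple_udit' so the key matches the class it maps to. A minimal sketch of how such a string-to-class registry resolves a config name; the names below are illustrative stand-ins, not flaxdiff's actual parse_config internals:

```python
# Minimal sketch of a string-to-class model registry of the kind
# parse_config appears to use; classes are stubs for illustration.
class SimpleDiT: ...
class SimpleUDiT: ...
class SimpleMMDiT: ...

MODEL_REGISTRY = {
    'simple_dit': SimpleDiT,
    'simple_mmdit': SimpleMMDiT,
    'simple_udit': SimpleUDiT,  # key now matches the class (was 'simple_uvit')
}

def resolve_model(name: str):
    # Fail loudly on unknown keys so config typos surface early.
    try:
        return MODEL_REGISTRY[name]
    except KeyError:
        raise ValueError(f"Unknown model '{name}'; expected one of {sorted(MODEL_REGISTRY)}")

print(resolve_model('simple_udit'))  # <class '__main__.SimpleUDiT'>
```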
flaxdiff/trainer/general_diffusion_trainer.py
CHANGED
@@ -427,7 +427,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
         process_index = jax.process_index()
         generate_samples = val_step_fn

-        val_ds = iter(val_ds) if val_ds else None
+        val_ds = iter(val_ds()) if val_ds else None
         print(f"Validation loop started for process index {process_index} with {global_device_count} devices.")
         # Evaluation step
         try:
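The fix calls `val_ds()` before wrapping it in `iter()`, which indicates the validation dataset is passed as a zero-argument factory rather than an iterable, consistent with `iter(data['train']())` in simple_trainer.py below. A sketch of that convention, with a hypothetical factory:

```python
# Sketch of the dataset-factory convention implied by the fix: the trainer
# receives a zero-argument callable that builds a fresh iterable per use.
# `make_val_dataset` is a stand-in for a real dataloader factory.
def make_val_dataset():
    return ({"batch": i} for i in range(3))

val_ds = make_val_dataset

# Old (buggy): iter(val_ds) raises TypeError because a function is not iterable.
# New: call the factory first, then build an iterator over its result.
val_iter = iter(val_ds()) if val_ds else None
print(next(val_iter))  # {'batch': 0}
```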
@@ -475,10 +475,11 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
             metrics = {k: np.mean(v) for k, v in metrics.items()}
             # Update the best validation metrics
             for key, value in metrics.items():
-
-
+                final_key = f"val/{key}"
+                if final_key not in self.best_val_metrics:
+                    self.best_val_metrics[final_key] = value
                 else:
-                    self.best_val_metrics[
+                    self.best_val_metrics[final_key] = min(self.best_val_metrics[final_key], value)
             # Log the best validation metrics
             if getattr(self, 'wandb', None) is not None and self.wandb:
                 # Log the metrics
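The new code keeps a running minimum per validation metric under a `val/`-prefixed key. A standalone sketch of that bookkeeping pattern; the real code stores the dict on the trainer as `self.best_val_metrics`:

```python
# Running-minimum tracking per metric, keyed as "val/<name>".
best_val_metrics = {}

def update_best(metrics: dict):
    for key, value in metrics.items():
        final_key = f"val/{key}"
        if final_key not in best_val_metrics:
            best_val_metrics[final_key] = value  # first observation
        else:
            best_val_metrics[final_key] = min(best_val_metrics[final_key], value)

update_best({"loss": 0.9})
update_best({"loss": 0.4})
update_best({"loss": 0.7})
print(best_val_metrics)  # {'val/loss': 0.4}
```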
@@ -488,6 +489,13 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
                 self.wandb.log({
                     f"val/{key}": value,
                 }, step=current_step)
+            # Log the best validation metrics
+            for key, value in self.best_val_metrics.items():
+                if isinstance(value, jnp.ndarray):
+                    value = np.array(value)
+                self.wandb.log({
+                    f"best_{key}": value,
+                }, step=current_step)
             print(f"Validation metrics for process index {process_index}: {metrics}")

             # Close validation dataset iterator
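The added block converts any JAX device array to NumPy before handing it to the logger, so wandb receives plain host values. A self-contained sketch of that host-transfer step, assuming only that the tracked values may be `jnp` arrays:

```python
import jax.numpy as jnp
import numpy as np

# Convert device arrays to host NumPy values before logging, mirroring
# the isinstance check added in the diff.
best_val_metrics = {"val/loss": jnp.asarray(0.4)}

loggable = {}
for key, value in best_val_metrics.items():
    if isinstance(value, jnp.ndarray):
        value = np.array(value)  # forces device-to-host transfer
    loggable[f"best_{key}"] = value

print(loggable)  # {'best_val/loss': array(0.4, dtype=float32)}
```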
@@ -622,10 +630,10 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
         if not runs:
             raise ValueError("No runs found in wandb.")
         print(f"Getting best runs from wandb {self.wandb.id}...")
-        runs = sorted(runs, key=lambda x: x.summary.get(metric, float('inf')))
+        runs = sorted(runs, key=lambda x: x.summary.get(f"best_{metric}", float('inf')))
         best_runs = runs[:top_k]
-        lower_bound = best_runs[-1].summary.get(metric, float('inf'))
-        upper_bound = best_runs[0].summary.get(metric, float('inf'))
+        lower_bound = best_runs[-1].summary.get(f"best_{metric}", float('inf'))
+        upper_bound = best_runs[0].summary.get(f"best_{metric}", float('inf'))
         print(f"Best runs from wandb {self.wandb.id}:")
         for run in best_runs:
             print(f"\t\tRun ID: {run.id}, Metric: {run.summary.get(metric, float('inf'))}")
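Run ranking now keys off `best_{metric}`, the summary key logged above, with `float('inf')` as the default so runs that never logged the metric sort last. A sketch using stand-in run objects in place of the wandb Api's runs; `metric` and `top_k` mirror the method's arguments:

```python
# FakeRun stands in for a wandb run: only .id and .summary are used here.
class FakeRun:
    def __init__(self, id, summary):
        self.id, self.summary = id, summary

runs = [
    FakeRun("a", {"best_val/loss": 0.5}),
    FakeRun("b", {"best_val/loss": 0.2}),
    FakeRun("c", {}),  # never logged the metric: sorts last via the default
]
metric, top_k = "val/loss", 2

runs = sorted(runs, key=lambda x: x.summary.get(f"best_{metric}", float('inf')))
best_runs = runs[:top_k]
lower_bound = best_runs[-1].summary.get(f"best_{metric}", float('inf'))
upper_bound = best_runs[0].summary.get(f"best_{metric}", float('inf'))
print([r.id for r in best_runs], (upper_bound, lower_bound))  # ['b', 'a'] (0.2, 0.5)
```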
@@ -649,19 +657,21 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
         best_runs, bounds = self.__get_best_general_runs__(metric=metric, top_k=top_k)

         # Determine if lower or higher values are better (for loss, lower is better)
-        is_lower_better =
+        is_lower_better = True

         # Check if current run is one of the best
         if metric == "train/best_loss":
             current_run_metric = self.best_loss
         elif metric in self.best_val_metrics:
+            print(f"Fetching best validation metric {metric} from local")
             current_run_metric = self.best_val_metrics[metric]
         else:
             current_run_metric = self.wandb.summary.get(metric, float('inf') if is_lower_better else float('-inf'))
-
+
+        print(f"Current run {self.wandb.id} metric: {current_run_metric}, Best bounds: {bounds}")
         # Check based on bounds
         if (is_lower_better and current_run_metric < bounds[1]) or (not is_lower_better and current_run_metric > bounds[0]):
-            print(f"Current run {self.wandb.id} meets performance criteria.")
+            print(f"Current run {self.wandb.id} meets performance criteria. Current metric: {current_run_metric}, Best bounds: {bounds}")
             is_best = (is_lower_better and current_run_metric < bounds[0]) or (not is_lower_better and current_run_metric > bounds[1])
             return True, is_best

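With `is_lower_better` now hardcoded to True, a run qualifies if it beats the worst of the current top-k (`bounds[1]`) and is the new best if it also beats the best (`bounds[0]`). A self-contained sketch of that two-level test, assuming bounds arrive as (best value, worst value) per the method above:

```python
# bounds = (upper_bound, lower_bound): best and worst values among the
# current top-k runs, as computed by __get_best_general_runs__.
def check_run(current_run_metric, bounds, is_lower_better=True):
    qualifies = (is_lower_better and current_run_metric < bounds[1]) or \
                (not is_lower_better and current_run_metric > bounds[0])
    if not qualifies:
        return False, False
    is_best = (is_lower_better and current_run_metric < bounds[0]) or \
              (not is_lower_better and current_run_metric > bounds[1])
    return True, is_best

print(check_run(0.3, (0.2, 0.5)))  # (True, False): inside the top k
print(check_run(0.1, (0.2, 0.5)))  # (True, True): new best run
print(check_run(0.9, (0.2, 0.5)))  # (False, False): does not qualify
```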
flaxdiff/trainer/simple_trainer.py
CHANGED
@@ -600,7 +600,7 @@ class SimpleTrainer:

     def fit(self, data, train_steps_per_epoch, epochs, train_step_args={}, val_steps_per_epoch=5, validation_step_args={}):
         train_ds = iter(data['train']())
-        val_ds = data.get('val', data.get('test', None))
+        val_ds = data.get('val', data.get('test', None))
         train_step = self._define_train_step(**train_step_args)
         val_step = self._define_validation_step(**validation_step_args)
         train_state = self.state
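`fit` pulls a 'train' factory from the `data` dict and falls back from 'val' to 'test' for validation, keeping the validation entry as an uncalled factory until the loop needs it. A sketch of the expected shape of that argument, with hypothetical loader factories:

```python
# Hypothetical dataloader factories; fit() appears to expect zero-argument
# callables that yield fresh batch iterables each time they are called.
def train_loader():
    return iter(range(100))

def val_loader():
    return iter(range(10))

data = {"train": train_loader, "val": val_loader}

train_ds = iter(data['train']())  # factory is called, then wrapped in iter()
val_ds = data.get('val', data.get('test', None))  # kept as a factory for later
```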
@@ -642,6 +642,19 @@ class SimpleTrainer:
             self.rngstate = rng_state
             total_time = end_time - start_time
             avg_time_per_step = total_time / train_steps_per_epoch
+
+            if val_steps_per_epoch > 0:
+                print(f"Validation started for process index {process_index}")
+                # Validation step
+                self.validation_loop(
+                    train_state,
+                    val_step,
+                    val_ds,
+                    val_steps_per_epoch,
+                    current_step,
+                )
+                print(colored(f"Validation done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
+
             avg_loss = epoch_loss / train_steps_per_epoch
             if avg_loss < self.best_loss:
                 self.best_loss = avg_loss
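Together with the removal below, this moves per-epoch validation from after the epoch summary to before it, so the average-loss and best-loss bookkeeping runs with validation already logged at the same step. A skeleton of the reordered loop under that reading, with stand-in helpers in place of the trainer's real ones:

```python
# Illustrative stand-ins, not the trainer's real helpers.
def run_train_steps():
    return 12.0  # pretend epoch_loss accumulated over the epoch

def validation_loop():
    print("validating...")

epochs, train_steps_per_epoch, val_steps_per_epoch = 2, 10, 5
best_loss = float('inf')

for current_epoch in range(epochs):
    epoch_loss = run_train_steps()        # training phase
    if val_steps_per_epoch > 0:
        validation_loop()                 # validation, moved up here in 0.2.10
    avg_loss = epoch_loss / train_steps_per_epoch
    best_loss = min(best_loss, avg_loss)  # epoch summary now comes last
```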
@@ -659,17 +672,6 @@ class SimpleTrainer:
                 }, step=current_step)
             print(colored(f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}", 'green'))

-            if val_steps_per_epoch > 0:
-                print(f"Validation started for process index {process_index}")
-                # Validation step
-                self.validation_loop(
-                    train_state,
-                    val_step,
-                    val_ds,
-                    val_steps_per_epoch,
-                    current_step,
-                )
-                print(colored(f"Validation done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))

         self.save(epochs)#
         return self.state
{flaxdiff-0.2.9.dist-info → flaxdiff-0.2.10.dist-info}/RECORD
RENAMED
@@ -15,7 +15,7 @@ flaxdiff/data/sources/videos.py,sha256=NkxwEruNpAwDCM53q4WurQ802gSjQMOqjNLxYOqjo
 flaxdiff/data/sources/voxceleb2.py,sha256=BoKfat_hsw6ObDyyaiQmPbBzuFiqgCGlgAZmf-t5Iz8,18621
 flaxdiff/inference/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/inference/pipeline.py,sha256=8S30FAlXEjvrDd87H-qdD6biySQZ3cJUflU8gdmPxig,9223
-flaxdiff/inference/utils.py,sha256=
+flaxdiff/inference/utils.py,sha256=Dh0KawgvQrZxyqN_9wbsb7gUyvPRendwb-YtAU6zIBE,12606
 flaxdiff/inputs/__init__.py,sha256=ybPjQsFAf5sqRVZG1sRiOl99EnwpI-NQ8HE3y7UbXmU,7197
 flaxdiff/inputs/encoders.py,sha256=pjfbx4Rk7bLoE80MOfThZDm6YtsDncRekmn0Bmg_CwI,2963
 flaxdiff/metrics/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -63,9 +63,9 @@ flaxdiff/schedulers/sqrt.py,sha256=mCd_szmOqF6vqQKiAiEOqV_3eBIPGYrW3VxK0o4rBuo,4
 flaxdiff/trainer/__init__.py,sha256=xSoierfi26gxfgxlNnwvyyPmuPAJ--5i3mEHxt3S-AE,215
 flaxdiff/trainer/autoencoder_trainer.py,sha256=2FP2P-k9c0n_k3eT0trkq73dQrHRdBj9ObK1idcyhSw,6996
 flaxdiff/trainer/diffusion_trainer.py,sha256=reQEVWKTqKAeyCMQ-curPOfSRmBKxKooK8EVtUuorcM,14599
-flaxdiff/trainer/general_diffusion_trainer.py,sha256=
-flaxdiff/trainer/simple_trainer.py,sha256=
-flaxdiff-0.2.
-flaxdiff-0.2.
-flaxdiff-0.2.
-flaxdiff-0.2.
+flaxdiff/trainer/general_diffusion_trainer.py,sha256=gMo0OOz8EFKGfiqZnDwhVSxtk_IUMGUvyt5TTr_Hk8g,30168
+flaxdiff/trainer/simple_trainer.py,sha256=nXYy9tadteG8N0RovpevPPEs6oeFvbr2gVq7Zot9l78,28754
+flaxdiff-0.2.10.dist-info/METADATA,sha256=xsqksvLSps2a9nNdvZkguWvsC07kX8A3Z26DPTq-tGI,24058
+flaxdiff-0.2.10.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
+flaxdiff-0.2.10.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.2.10.dist-info/RECORD,,
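For reference, each RECORD row is `path,sha256=<digest>,<size>`, where the digest is the urlsafe base64 encoding of the file's SHA-256 hash with trailing `=` padding stripped, per the wheel spec. A sketch that derives such a row for a throwaway file:

```python
import base64
import hashlib
import os

# Derive a wheel-RECORD row: path, "sha256=" + urlsafe-base64 digest with
# '=' padding stripped, then the file size in bytes.
def record_row(path: str) -> str:
    with open(path, 'rb') as f:
        data = f.read()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest())
    return f"{path},sha256={digest.rstrip(b'=').decode()},{len(data)}"

# Example against a temporary file:
with open('demo.txt', 'wb') as f:
    f.write(b'hello')
print(record_row('demo.txt'))  # demo.txt,sha256=LPJNul-wow4m6Dsqxbnin...,5
os.remove('demo.txt')
```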
{flaxdiff-0.2.9.dist-info → flaxdiff-0.2.10.dist-info}/WHEEL
File without changes
{flaxdiff-0.2.9.dist-info → flaxdiff-0.2.10.dist-info}/top_level.txt
File without changes