flaxdiff 0.2.0__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/PKG-INFO +1 -1
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/data/sources/av_utils.py +1 -1
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/trainer/general_diffusion_trainer.py +28 -13
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff.egg-info/PKG-INFO +1 -1
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/pyproject.toml +1 -1
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/README.md +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/__init__.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/data/__init__.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/data/benchmark_decord.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/data/dataloaders.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/data/dataset_map.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/data/online_loader.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/data/sources/audio_utils.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/data/sources/av_example.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/data/sources/base.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/data/sources/images.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/data/sources/utils.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/data/sources/videos.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/data/sources/voxceleb2.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/inference/__init__.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/inference/pipeline.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/inference/utils.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/inputs/__init__.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/inputs/encoders.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/metrics/inception.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/metrics/psnr.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/metrics/ssim.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/metrics/utils.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/models/__init__.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/models/attention.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/models/autoencoder/__init__.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/models/autoencoder/diffusers.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/models/common.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/models/favor_fastattn.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/models/general.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/models/simple_unet.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/models/simple_vit.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/models/unet_3d.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/models/unet_3d_blocks.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/predictors/__init__.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/samplers/__init__.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/samplers/common.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/samplers/ddim.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/samplers/ddpm.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/samplers/euler.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/samplers/heun_sampler.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/samplers/multistep_dpm.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/samplers/rk4_sampler.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/schedulers/__init__.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/schedulers/common.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/schedulers/continuous.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/schedulers/cosine.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/schedulers/discrete.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/schedulers/exp.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/schedulers/karras.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/schedulers/linear.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/schedulers/sqrt.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/trainer/__init__.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/trainer/diffusion_trainer.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/trainer/simple_trainer.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff/utils.py +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff.egg-info/SOURCES.txt +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff.egg-info/requires.txt +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/flaxdiff.egg-info/top_level.txt +0 -0
- {flaxdiff-0.2.0 → flaxdiff-0.2.1}/setup.cfg +0 -0
@@ -7,7 +7,6 @@ import shutil
|
|
7
7
|
import subprocess
|
8
8
|
import numpy as np
|
9
9
|
from typing import Tuple, Optional, Union, List
|
10
|
-
from video_reader import PyVideoReader
|
11
10
|
from .audio_utils import read_audio
|
12
11
|
|
13
12
|
def get_video_fps(video_path: str):
|
@@ -113,6 +112,7 @@ def read_av_improved(
|
|
113
112
|
Returns:
|
114
113
|
Tuple of (audio_data, video_frames) where video_frames is a numpy array.
|
115
114
|
"""
|
115
|
+
from video_reader import PyVideoReader
|
116
116
|
# Calculate time information for audio extraction
|
117
117
|
start_time = start / fps if start > 0 else 0
|
118
118
|
duration = None
|
@@ -484,11 +484,13 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
484
484
|
def push_to_registry(
|
485
485
|
self,
|
486
486
|
registry_name: str = 'wandb-registry-model',
|
487
|
+
aliases: List[str] = ['latest'],
|
487
488
|
):
|
488
489
|
"""
|
489
490
|
Push the model to wandb registry.
|
490
491
|
Args:
|
491
492
|
registry_name: Name of the model registry.
|
493
|
+
aliases: List of aliases for the model.
|
492
494
|
"""
|
493
495
|
if self.wandb is None:
|
494
496
|
raise ValueError("Wandb is not initialized. Cannot push to registry.")
|
@@ -502,6 +504,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
502
504
|
artifact_or_path=latest_checkpoint_path,
|
503
505
|
name=modelname,
|
504
506
|
type="model",
|
507
|
+
aliases=aliases,
|
505
508
|
)
|
506
509
|
|
507
510
|
target_path = f"{registry_name}/{modelname}"
|
@@ -541,6 +544,15 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
541
544
|
return best_runs, (min(lower_bound, upper_bound), max(lower_bound, upper_bound))
|
542
545
|
|
543
546
|
def __compare_run_against_best__(self, top_k=2, metric="train/best_loss"):
|
547
|
+
"""
|
548
|
+
Compare the current run against the best runs from the sweep.
|
549
|
+
Args:
|
550
|
+
top_k: Number of top runs to consider.
|
551
|
+
metric: Metric to compare against.
|
552
|
+
Returns:
|
553
|
+
is_good: Whether the current run is among the best.
|
554
|
+
is_best: Whether the current run is the best.
|
555
|
+
"""
|
544
556
|
# Get best runs
|
545
557
|
best_runs, bounds = self.__get_best_sweep_runs__(metric=metric, top_k=top_k)
|
546
558
|
|
@@ -548,20 +560,18 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
548
560
|
is_lower_better = "loss" in metric.lower()
|
549
561
|
|
550
562
|
# Check if current run is one of the best
|
551
|
-
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
if run.id == self.wandb.id:
|
556
|
-
print(f"Current run {self.wandb.id} is one of the best runs.")
|
557
|
-
return True
|
563
|
+
if metric == "train/best_loss":
|
564
|
+
current_run_metric = self.best_loss
|
565
|
+
else:
|
566
|
+
current_run_metric = self.wandb.summary.get(metric, float('inf') if is_lower_better else float('-inf'))
|
558
567
|
|
559
|
-
#
|
568
|
+
# Check based on bounds
|
560
569
|
if (is_lower_better and current_run_metric < bounds[1]) or (not is_lower_better and current_run_metric > bounds[0]):
|
561
570
|
print(f"Current run {self.wandb.id} meets performance criteria.")
|
562
|
-
|
571
|
+
is_best = (is_lower_better and current_run_metric < bounds[0]) or (not is_lower_better and current_run_metric > bounds[1])
|
572
|
+
return True, is_best
|
563
573
|
|
564
|
-
return False
|
574
|
+
return False, False
|
565
575
|
|
566
576
|
def save(self, epoch=0, step=0, state=None, rngstate=None):
|
567
577
|
super().save(epoch=epoch, step=step, state=state, rngstate=rngstate)
|
@@ -569,9 +579,14 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
569
579
|
if self.wandb is not None and hasattr(self, "wandb_sweep"):
|
570
580
|
checkpoint = get_latest_checkpoint(self.checkpoint_path())
|
571
581
|
try:
|
572
|
-
|
573
|
-
|
574
|
-
|
582
|
+
is_good, is_best = self.__compare_run_against_best__(top_k=5, metric="train/best_loss")
|
583
|
+
if is_good:
|
584
|
+
# Push to registry with appropriate aliases
|
585
|
+
aliases = ["latest"]
|
586
|
+
if is_best:
|
587
|
+
aliases.append("best")
|
588
|
+
self.push_to_registry(aliases=aliases)
|
589
|
+
print("Model pushed to registry successfully with aliases:", aliases)
|
575
590
|
else:
|
576
591
|
print("Current run is not one of the best runs. Not saving model.")
|
577
592
|
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|