flaxdiff 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/sources/av_utils.py +1 -1
- flaxdiff/trainer/general_diffusion_trainer.py +29 -13
- {flaxdiff-0.2.0.dist-info → flaxdiff-0.2.2.dist-info}/METADATA +1 -1
- {flaxdiff-0.2.0.dist-info → flaxdiff-0.2.2.dist-info}/RECORD +6 -6
- {flaxdiff-0.2.0.dist-info → flaxdiff-0.2.2.dist-info}/WHEEL +1 -1
- {flaxdiff-0.2.0.dist-info → flaxdiff-0.2.2.dist-info}/top_level.txt +0 -0
@@ -7,7 +7,6 @@ import shutil
|
|
7
7
|
import subprocess
|
8
8
|
import numpy as np
|
9
9
|
from typing import Tuple, Optional, Union, List
|
10
|
-
from video_reader import PyVideoReader
|
11
10
|
from .audio_utils import read_audio
|
12
11
|
|
13
12
|
def get_video_fps(video_path: str):
|
@@ -113,6 +112,7 @@ def read_av_improved(
|
|
113
112
|
Returns:
|
114
113
|
Tuple of (audio_data, video_frames) where video_frames is a numpy array.
|
115
114
|
"""
|
115
|
+
from video_reader import PyVideoReader
|
116
116
|
# Calculate time information for audio extraction
|
117
117
|
start_time = start / fps if start > 0 else 0
|
118
118
|
duration = None
|
@@ -484,11 +484,13 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
484
484
|
def push_to_registry(
|
485
485
|
self,
|
486
486
|
registry_name: str = 'wandb-registry-model',
|
487
|
+
aliases: List[str] = [],
|
487
488
|
):
|
488
489
|
"""
|
489
490
|
Push the model to wandb registry.
|
490
491
|
Args:
|
491
492
|
registry_name: Name of the model registry.
|
493
|
+
aliases: List of aliases for the model.
|
492
494
|
"""
|
493
495
|
if self.wandb is None:
|
494
496
|
raise ValueError("Wandb is not initialized. Cannot push to registry.")
|
@@ -502,6 +504,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
502
504
|
artifact_or_path=latest_checkpoint_path,
|
503
505
|
name=modelname,
|
504
506
|
type="model",
|
507
|
+
aliases=['latest'] + aliases,
|
505
508
|
)
|
506
509
|
|
507
510
|
target_path = f"{registry_name}/{modelname}"
|
@@ -509,6 +512,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
509
512
|
self.wandb.link_artifact(
|
510
513
|
artifact=logged_artifact,
|
511
514
|
target_path=target_path,
|
515
|
+
aliases=aliases,
|
512
516
|
)
|
513
517
|
print(f"Model pushed to registry at {target_path}")
|
514
518
|
return logged_artifact
|
@@ -541,6 +545,15 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
541
545
|
return best_runs, (min(lower_bound, upper_bound), max(lower_bound, upper_bound))
|
542
546
|
|
543
547
|
def __compare_run_against_best__(self, top_k=2, metric="train/best_loss"):
|
548
|
+
"""
|
549
|
+
Compare the current run against the best runs from the sweep.
|
550
|
+
Args:
|
551
|
+
top_k: Number of top runs to consider.
|
552
|
+
metric: Metric to compare against.
|
553
|
+
Returns:
|
554
|
+
is_good: Whether the current run is among the best.
|
555
|
+
is_best: Whether the current run is the best.
|
556
|
+
"""
|
544
557
|
# Get best runs
|
545
558
|
best_runs, bounds = self.__get_best_sweep_runs__(metric=metric, top_k=top_k)
|
546
559
|
|
@@ -548,20 +561,18 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
548
561
|
is_lower_better = "loss" in metric.lower()
|
549
562
|
|
550
563
|
# Check if current run is one of the best
|
551
|
-
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
if run.id == self.wandb.id:
|
556
|
-
print(f"Current run {self.wandb.id} is one of the best runs.")
|
557
|
-
return True
|
564
|
+
if metric == "train/best_loss":
|
565
|
+
current_run_metric = self.best_loss
|
566
|
+
else:
|
567
|
+
current_run_metric = self.wandb.summary.get(metric, float('inf') if is_lower_better else float('-inf'))
|
558
568
|
|
559
|
-
#
|
569
|
+
# Check based on bounds
|
560
570
|
if (is_lower_better and current_run_metric < bounds[1]) or (not is_lower_better and current_run_metric > bounds[0]):
|
561
571
|
print(f"Current run {self.wandb.id} meets performance criteria.")
|
562
|
-
|
572
|
+
is_best = (is_lower_better and current_run_metric < bounds[0]) or (not is_lower_better and current_run_metric > bounds[1])
|
573
|
+
return True, is_best
|
563
574
|
|
564
|
-
return False
|
575
|
+
return False, False
|
565
576
|
|
566
577
|
def save(self, epoch=0, step=0, state=None, rngstate=None):
|
567
578
|
super().save(epoch=epoch, step=step, state=state, rngstate=rngstate)
|
@@ -569,9 +580,14 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
569
580
|
if self.wandb is not None and hasattr(self, "wandb_sweep"):
|
570
581
|
checkpoint = get_latest_checkpoint(self.checkpoint_path())
|
571
582
|
try:
|
572
|
-
|
573
|
-
|
574
|
-
|
583
|
+
is_good, is_best = self.__compare_run_against_best__(top_k=5, metric="train/best_loss")
|
584
|
+
if is_good:
|
585
|
+
# Push to registry with appropriate aliases
|
586
|
+
aliases = []
|
587
|
+
if is_best:
|
588
|
+
aliases.append("best")
|
589
|
+
self.push_to_registry(aliases=aliases)
|
590
|
+
print("Model pushed to registry successfully with aliases:", aliases)
|
575
591
|
else:
|
576
592
|
print("Current run is not one of the best runs. Not saving model.")
|
577
593
|
|
@@ -7,7 +7,7 @@ flaxdiff/data/dataset_map.py,sha256=NrLG1XtIxy8GcCsZ-e6eascjgsP0Xq5lVA1z3HIIYyI,
|
|
7
7
|
flaxdiff/data/online_loader.py,sha256=t1jEhdB6gWTlwx68ehj1ol_PrImbwXYiRlrJPCmNgCM,35701
|
8
8
|
flaxdiff/data/sources/audio_utils.py,sha256=X27gG1yQt_abVOYgMtruYmZD7-8_uQCRhhTSpn4clkI,4514
|
9
9
|
flaxdiff/data/sources/av_example.py,sha256=RIcbVKqckFqbfnV65NQotzIBxjdDuM67kD1nY8fqw5Q,3826
|
10
|
-
flaxdiff/data/sources/av_utils.py,sha256=
|
10
|
+
flaxdiff/data/sources/av_utils.py,sha256=LCr9MJNurOaoxY-sjzkLqJS_MlX0x3gRSlKAVIglAU0,24045
|
11
11
|
flaxdiff/data/sources/base.py,sha256=uhF0odJSYRy0SLw1xnI9Q_q_xiVht2DmEYcX1j9AWT4,4246
|
12
12
|
flaxdiff/data/sources/images.py,sha256=WpH4ywZhNol26peX3m6m5NrmDJ1K2s6fRcYHvOFlOk8,11102
|
13
13
|
flaxdiff/data/sources/utils.py,sha256=kFzM4_kPoThbAu54ulABmEDAR33tR50NgzXIpC0Dzjk,7316
|
@@ -56,9 +56,9 @@ flaxdiff/schedulers/sqrt.py,sha256=mCd_szmOqF6vqQKiAiEOqV_3eBIPGYrW3VxK0o4rBuo,4
|
|
56
56
|
flaxdiff/trainer/__init__.py,sha256=xSoierfi26gxfgxlNnwvyyPmuPAJ--5i3mEHxt3S-AE,215
|
57
57
|
flaxdiff/trainer/autoencoder_trainer.py,sha256=2FP2P-k9c0n_k3eT0trkq73dQrHRdBj9ObK1idcyhSw,6996
|
58
58
|
flaxdiff/trainer/diffusion_trainer.py,sha256=reQEVWKTqKAeyCMQ-curPOfSRmBKxKooK8EVtUuorcM,14599
|
59
|
-
flaxdiff/trainer/general_diffusion_trainer.py,sha256=
|
59
|
+
flaxdiff/trainer/general_diffusion_trainer.py,sha256=7VAeT3TzCDUyns8wdZbIwXJqDKx_FYSzq8toOkaeQMI,24802
|
60
60
|
flaxdiff/trainer/simple_trainer.py,sha256=CF2mMcc6AtBgcR1XiqKevRL0paGS0S9ZJofCns32nRM,24214
|
61
|
-
flaxdiff-0.2.
|
62
|
-
flaxdiff-0.2.
|
63
|
-
flaxdiff-0.2.
|
64
|
-
flaxdiff-0.2.
|
61
|
+
flaxdiff-0.2.2.dist-info/METADATA,sha256=pzYYdy1zK7lbaqSRdpopZHHYx7q3BP0DL11hGTOO7h4,23982
|
62
|
+
flaxdiff-0.2.2.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
|
63
|
+
flaxdiff-0.2.2.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
64
|
+
flaxdiff-0.2.2.dist-info/RECORD,,
|
File without changes
|