flaxdiff 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,7 +7,6 @@ import shutil
7
7
  import subprocess
8
8
  import numpy as np
9
9
  from typing import Tuple, Optional, Union, List
10
- from video_reader import PyVideoReader
11
10
  from .audio_utils import read_audio
12
11
 
13
12
  def get_video_fps(video_path: str):
@@ -113,6 +112,7 @@ def read_av_improved(
113
112
  Returns:
114
113
  Tuple of (audio_data, video_frames) where video_frames is a numpy array.
115
114
  """
115
+ from video_reader import PyVideoReader
116
116
  # Calculate time information for audio extraction
117
117
  start_time = start / fps if start > 0 else 0
118
118
  duration = None
@@ -484,11 +484,13 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
484
484
  def push_to_registry(
485
485
  self,
486
486
  registry_name: str = 'wandb-registry-model',
487
+ aliases: List[str] = [],
487
488
  ):
488
489
  """
489
490
  Push the model to wandb registry.
490
491
  Args:
491
492
  registry_name: Name of the model registry.
493
+ aliases: List of aliases for the model.
492
494
  """
493
495
  if self.wandb is None:
494
496
  raise ValueError("Wandb is not initialized. Cannot push to registry.")
@@ -502,6 +504,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
502
504
  artifact_or_path=latest_checkpoint_path,
503
505
  name=modelname,
504
506
  type="model",
507
+ aliases=['latest'] + aliases,
505
508
  )
506
509
 
507
510
  target_path = f"{registry_name}/{modelname}"
@@ -509,6 +512,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
509
512
  self.wandb.link_artifact(
510
513
  artifact=logged_artifact,
511
514
  target_path=target_path,
515
+ aliases=aliases,
512
516
  )
513
517
  print(f"Model pushed to registry at {target_path}")
514
518
  return logged_artifact
@@ -541,6 +545,15 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
541
545
  return best_runs, (min(lower_bound, upper_bound), max(lower_bound, upper_bound))
542
546
 
543
547
  def __compare_run_against_best__(self, top_k=2, metric="train/best_loss"):
548
+ """
549
+ Compare the current run against the best runs from the sweep.
550
+ Args:
551
+ top_k: Number of top runs to consider.
552
+ metric: Metric to compare against.
553
+ Returns:
554
+ is_good: Whether the current run is among the best.
555
+ is_best: Whether the current run is the best.
556
+ """
544
557
  # Get best runs
545
558
  best_runs, bounds = self.__get_best_sweep_runs__(metric=metric, top_k=top_k)
546
559
 
@@ -548,20 +561,18 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
548
561
  is_lower_better = "loss" in metric.lower()
549
562
 
550
563
  # Check if current run is one of the best
551
- current_run_metric = self.wandb.summary.get(metric, float('inf') if is_lower_better else float('-inf'))
552
-
553
- # Direct check if current run is in best runs
554
- for run in best_runs:
555
- if run.id == self.wandb.id:
556
- print(f"Current run {self.wandb.id} is one of the best runs.")
557
- return True
564
+ if metric == "train/best_loss":
565
+ current_run_metric = self.best_loss
566
+ else:
567
+ current_run_metric = self.wandb.summary.get(metric, float('inf') if is_lower_better else float('-inf'))
558
568
 
559
- # Backup check based on metric value
569
+ # Check based on bounds
560
570
  if (is_lower_better and current_run_metric < bounds[1]) or (not is_lower_better and current_run_metric > bounds[0]):
561
571
  print(f"Current run {self.wandb.id} meets performance criteria.")
562
- return True
572
+ is_best = (is_lower_better and current_run_metric < bounds[0]) or (not is_lower_better and current_run_metric > bounds[1])
573
+ return True, is_best
563
574
 
564
- return False
575
+ return False, False
565
576
 
566
577
  def save(self, epoch=0, step=0, state=None, rngstate=None):
567
578
  super().save(epoch=epoch, step=step, state=state, rngstate=rngstate)
@@ -569,9 +580,14 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
569
580
  if self.wandb is not None and hasattr(self, "wandb_sweep"):
570
581
  checkpoint = get_latest_checkpoint(self.checkpoint_path())
571
582
  try:
572
- if self.__compare_run_against_best__(top_k=5, metric="train/best_loss"):
573
- self.push_to_registry()
574
- print("Model pushed to registry successfully")
583
+ is_good, is_best = self.__compare_run_against_best__(top_k=5, metric="train/best_loss")
584
+ if is_good:
585
+ # Push to registry with appropriate aliases
586
+ aliases = []
587
+ if is_best:
588
+ aliases.append("best")
589
+ self.push_to_registry(aliases=aliases)
590
+ print("Model pushed to registry successfully with aliases:", aliases)
575
591
  else:
576
592
  print("Current run is not one of the best runs. Not saving model.")
577
593
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.2.0
3
+ Version: 0.2.2
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
@@ -7,7 +7,7 @@ flaxdiff/data/dataset_map.py,sha256=NrLG1XtIxy8GcCsZ-e6eascjgsP0Xq5lVA1z3HIIYyI,
7
7
  flaxdiff/data/online_loader.py,sha256=t1jEhdB6gWTlwx68ehj1ol_PrImbwXYiRlrJPCmNgCM,35701
8
8
  flaxdiff/data/sources/audio_utils.py,sha256=X27gG1yQt_abVOYgMtruYmZD7-8_uQCRhhTSpn4clkI,4514
9
9
  flaxdiff/data/sources/av_example.py,sha256=RIcbVKqckFqbfnV65NQotzIBxjdDuM67kD1nY8fqw5Q,3826
10
- flaxdiff/data/sources/av_utils.py,sha256=n2qwMBQGouoBH025vdE7gitWC6RduUommUrs-SPdWe4,24041
10
+ flaxdiff/data/sources/av_utils.py,sha256=LCr9MJNurOaoxY-sjzkLqJS_MlX0x3gRSlKAVIglAU0,24045
11
11
  flaxdiff/data/sources/base.py,sha256=uhF0odJSYRy0SLw1xnI9Q_q_xiVht2DmEYcX1j9AWT4,4246
12
12
  flaxdiff/data/sources/images.py,sha256=WpH4ywZhNol26peX3m6m5NrmDJ1K2s6fRcYHvOFlOk8,11102
13
13
  flaxdiff/data/sources/utils.py,sha256=kFzM4_kPoThbAu54ulABmEDAR33tR50NgzXIpC0Dzjk,7316
@@ -56,9 +56,9 @@ flaxdiff/schedulers/sqrt.py,sha256=mCd_szmOqF6vqQKiAiEOqV_3eBIPGYrW3VxK0o4rBuo,4
56
56
  flaxdiff/trainer/__init__.py,sha256=xSoierfi26gxfgxlNnwvyyPmuPAJ--5i3mEHxt3S-AE,215
57
57
  flaxdiff/trainer/autoencoder_trainer.py,sha256=2FP2P-k9c0n_k3eT0trkq73dQrHRdBj9ObK1idcyhSw,6996
58
58
  flaxdiff/trainer/diffusion_trainer.py,sha256=reQEVWKTqKAeyCMQ-curPOfSRmBKxKooK8EVtUuorcM,14599
59
- flaxdiff/trainer/general_diffusion_trainer.py,sha256=VQ5p2ZaTv2R1LM0Epz4e719_EfK2dh1eoKK3WIysIW0,24040
59
+ flaxdiff/trainer/general_diffusion_trainer.py,sha256=7VAeT3TzCDUyns8wdZbIwXJqDKx_FYSzq8toOkaeQMI,24802
60
60
  flaxdiff/trainer/simple_trainer.py,sha256=CF2mMcc6AtBgcR1XiqKevRL0paGS0S9ZJofCns32nRM,24214
61
- flaxdiff-0.2.0.dist-info/METADATA,sha256=1WLpd9RQy_mJE2E2uOdXptY5Fm3n_MTNcgZyBD7YmGw,23982
62
- flaxdiff-0.2.0.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
63
- flaxdiff-0.2.0.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
64
- flaxdiff-0.2.0.dist-info/RECORD,,
61
+ flaxdiff-0.2.2.dist-info/METADATA,sha256=pzYYdy1zK7lbaqSRdpopZHHYx7q3BP0DL11hGTOO7h4,23982
62
+ flaxdiff-0.2.2.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
63
+ flaxdiff-0.2.2.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
64
+ flaxdiff-0.2.2.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (78.1.1)
2
+ Generator: setuptools (79.0.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5