flaxdiff 0.2.6__tar.gz → 0.2.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/PKG-INFO +1 -1
  2. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/data/sources/images.py +3 -1
  3. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/inference/pipeline.py +10 -5
  4. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/inference/utils.py +2 -2
  5. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff.egg-info/PKG-INFO +1 -1
  6. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/pyproject.toml +1 -1
  7. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/README.md +0 -0
  8. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/__init__.py +0 -0
  9. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/data/__init__.py +0 -0
  10. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/data/benchmark_decord.py +0 -0
  11. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/data/dataloaders.py +0 -0
  12. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/data/dataset_map.py +0 -0
  13. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/data/online_loader.py +0 -0
  14. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/data/sources/audio_utils.py +0 -0
  15. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/data/sources/av_example.py +0 -0
  16. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/data/sources/av_utils.py +0 -0
  17. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/data/sources/base.py +0 -0
  18. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/data/sources/utils.py +0 -0
  19. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/data/sources/videos.py +0 -0
  20. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/data/sources/voxceleb2.py +0 -0
  21. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/inference/__init__.py +0 -0
  22. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/inputs/__init__.py +0 -0
  23. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/inputs/encoders.py +0 -0
  24. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/metrics/__init__.py +0 -0
  25. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/metrics/common.py +0 -0
  26. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/metrics/images.py +0 -0
  27. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/metrics/inception.py +0 -0
  28. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/metrics/psnr.py +0 -0
  29. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/metrics/ssim.py +0 -0
  30. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/metrics/utils.py +0 -0
  31. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/models/__init__.py +0 -0
  32. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/models/attention.py +0 -0
  33. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/models/autoencoder/__init__.py +0 -0
  34. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  35. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/models/autoencoder/diffusers.py +0 -0
  36. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  37. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/models/better_uvit.py +0 -0
  38. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/models/common.py +0 -0
  39. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/models/favor_fastattn.py +0 -0
  40. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/models/general.py +0 -0
  41. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/models/simple_unet.py +0 -0
  42. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/models/simple_vit.py +0 -0
  43. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/models/unet_3d.py +0 -0
  44. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/models/unet_3d_blocks.py +0 -0
  45. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/predictors/__init__.py +0 -0
  46. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/samplers/__init__.py +0 -0
  47. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/samplers/common.py +0 -0
  48. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/samplers/ddim.py +0 -0
  49. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/samplers/ddpm.py +0 -0
  50. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/samplers/euler.py +0 -0
  51. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/samplers/heun_sampler.py +0 -0
  52. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/samplers/multistep_dpm.py +0 -0
  53. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/samplers/rk4_sampler.py +0 -0
  54. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/schedulers/__init__.py +0 -0
  55. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/schedulers/common.py +0 -0
  56. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/schedulers/continuous.py +0 -0
  57. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/schedulers/cosine.py +0 -0
  58. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/schedulers/discrete.py +0 -0
  59. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/schedulers/exp.py +0 -0
  60. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/schedulers/karras.py +0 -0
  61. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/schedulers/linear.py +0 -0
  62. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/schedulers/sqrt.py +0 -0
  63. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/trainer/__init__.py +0 -0
  64. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
  65. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/trainer/diffusion_trainer.py +0 -0
  66. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/trainer/general_diffusion_trainer.py +0 -0
  67. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/trainer/simple_trainer.py +0 -0
  68. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff/utils.py +0 -0
  69. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff.egg-info/SOURCES.txt +0 -0
  70. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff.egg-info/dependency_links.txt +0 -0
  71. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff.egg-info/requires.txt +0 -0
  72. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/flaxdiff.egg-info/top_level.txt +0 -0
  73. {flaxdiff-0.2.6 → flaxdiff-0.2.7}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.2.6
3
+ Version: 0.2.7
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
@@ -113,12 +113,14 @@ class ImageTFDSSource(DataSource):
113
113
  class ImageTFDSAugmenter(DataAugmenter):
114
114
  """Augmenter for TFDS image datasets."""
115
115
 
116
- def __init__(self, label_path: str = "/home/mrwhite0racle/tensorflow_datasets/oxford_flowers102/2.1.1/label.labels.txt"):
116
+ def __init__(self, label_path: str = None):
117
117
  """Initialize a TFDS image augmenter.
118
118
 
119
119
  Args:
120
120
  label_path: Path to the labels file for datasets like Oxford Flowers.
121
121
  """
122
+ if label_path is None:
123
+ label_path = os.path.join(os.path.expanduser("~"), "tensorflow_datasets/oxford_flowers102/2.1.1/label.labels.txt")
122
124
  self.label_path = label_path
123
125
 
124
126
  def create_transform(self, image_scale: int = 256, method: Any = None) -> Callable[[], pygrain.MapTransform]:
@@ -20,11 +20,12 @@ from flaxdiff.models.autoencoder import AutoEncoder
20
20
  from flaxdiff.inputs import DiffusionInputConfig
21
21
  from flaxdiff.utils import defaultTextEncodeModel, RandomMarkovState
22
22
  from flaxdiff.samplers.euler import EulerAncestralSampler
23
- from .utils import parse_config, load_from_wandb_run, load_from_wandb_registry
23
+ from flaxdiff.inference.utils import parse_config, load_from_wandb_run, load_from_wandb_registry
24
24
 
25
25
  @dataclass
26
26
  class InferencePipeline:
27
27
  """Inference pipeline for a general model."""
28
+ name: str = None
28
29
  model: nn.Module = None
29
30
  state: SimpleTrainState = None
30
31
  best_state: SimpleTrainState = None
@@ -44,6 +45,7 @@ class DiffusionInferencePipeline(InferencePipeline):
44
45
  This pipeline handles loading models from wandb and generating samples using the
45
46
  DiffusionSampler from FlaxDiff.
46
47
  """
48
+ artifact: Any = None
47
49
  state: TrainState = None
48
50
  best_state: TrainState = None
49
51
  rngstate: Optional[RandomMarkovState] = None
@@ -53,7 +55,6 @@ class DiffusionInferencePipeline(InferencePipeline):
53
55
  input_config: DiffusionInputConfig = None
54
56
  samplers: Dict[Type[DiffusionSampler], Dict[float, DiffusionSampler]] = field(default_factory=dict)
55
57
  config: Dict[str, Any] = field(default_factory=dict)
56
- wandb_run = None
57
58
 
58
59
  @classmethod
59
60
  def from_wandb_run(
@@ -76,7 +77,7 @@ class DiffusionInferencePipeline(InferencePipeline):
76
77
  Returns:
77
78
  DiffusionInferencePipeline instance
78
79
  """
79
- states, config, run = load_from_wandb_run(
80
+ states, config, run, artifact = load_from_wandb_run(
80
81
  wandb_run,
81
82
  project=project,
82
83
  entity=entity,
@@ -95,6 +96,7 @@ class DiffusionInferencePipeline(InferencePipeline):
95
96
  best_state=best_state,
96
97
  rngstate=RandomMarkovState(jax.random.PRNGKey(42)),
97
98
  run=run,
99
+ artifact=artifact,
98
100
  )
99
101
  return pipeline
100
102
 
@@ -119,7 +121,7 @@ class DiffusionInferencePipeline(InferencePipeline):
119
121
  Returns:
120
122
  DiffusionInferencePipeline instance
121
123
  """
122
- states, config, run = load_from_wandb_registry(
124
+ states, config, run, artifact = load_from_wandb_registry(
123
125
  modelname=modelname,
124
126
  project=project,
125
127
  entity=entity,
@@ -140,6 +142,7 @@ class DiffusionInferencePipeline(InferencePipeline):
140
142
  best_state=best_state,
141
143
  rngstate=RandomMarkovState(jax.random.PRNGKey(42)),
142
144
  run=run,
145
+ artifact=artifact,
143
146
  )
144
147
  return pipeline
145
148
 
@@ -151,11 +154,14 @@ class DiffusionInferencePipeline(InferencePipeline):
151
154
  best_state: Optional[Dict[str, Any]] = None,
152
155
  rngstate: Optional[RandomMarkovState] = None,
153
156
  run=None,
157
+ artifact=None,
154
158
  ):
155
159
  if rngstate is None:
156
160
  rngstate = RandomMarkovState(jax.random.PRNGKey(42))
157
161
  # Build and return pipeline
158
162
  return cls(
163
+ name=run.name if run else None,
164
+ artifact=artifact,
159
165
  model=config['model'],
160
166
  state=state,
161
167
  best_state=best_state,
@@ -165,7 +171,6 @@ class DiffusionInferencePipeline(InferencePipeline):
165
171
  autoencoder=config['autoencoder'],
166
172
  input_config=config['input_config'],
167
173
  config=config,
168
- wandb_run=run,
169
174
  )
170
175
 
171
176
  def get_sampler(
@@ -292,7 +292,7 @@ def load_from_wandb_run(
292
292
  config = run.config
293
293
  except Exception as e:
294
294
  print(f"Warning: Failed to load model from wandb: {e}")
295
- return states, config, run
295
+ return states, config, run, artifact
296
296
 
297
297
  def load_from_wandb_registry(
298
298
  modelname: str,
@@ -318,4 +318,4 @@ def load_from_wandb_registry(
318
318
  config = run.config
319
319
  except Exception as e:
320
320
  print(f"Warning: Failed to load model from wandb: {e}")
321
- return states, config, run
321
+ return states, config, run, artifact
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.2.6
3
+ Version: 0.2.7
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "flaxdiff"
7
- version = "0.2.6"
7
+ version = "0.2.7"
8
8
  description = "A versatile and easy to understand Diffusion library"
9
9
  readme = "README.md"
10
10
  authors = [
File without changes
File without changes
File without changes
File without changes