flaxdiff 0.1.37.2__tar.gz → 0.1.37.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/PKG-INFO +1 -1
  2. flaxdiff-0.1.37.3/flaxdiff/metrics/psnr.py +0 -0
  3. flaxdiff-0.1.37.3/flaxdiff/metrics/ssim.py +0 -0
  4. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/samplers/common.py +7 -9
  5. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff.egg-info/PKG-INFO +1 -1
  6. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff.egg-info/SOURCES.txt +2 -0
  7. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/pyproject.toml +1 -1
  8. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/README.md +0 -0
  9. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/__init__.py +0 -0
  10. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/data/__init__.py +0 -0
  11. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/data/dataset_map.py +0 -0
  12. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/data/datasets.py +0 -0
  13. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/data/online_loader.py +0 -0
  14. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/data/sources/gcs.py +0 -0
  15. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/data/sources/tfds.py +0 -0
  16. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/metrics/inception.py +0 -0
  17. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/metrics/utils.py +0 -0
  18. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/models/__init__.py +0 -0
  19. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/models/attention.py +0 -0
  20. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/models/autoencoder/__init__.py +0 -0
  21. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  22. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/models/autoencoder/diffusers.py +0 -0
  23. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  24. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/models/common.py +0 -0
  25. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/models/favor_fastattn.py +0 -0
  26. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/models/simple_unet.py +0 -0
  27. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/models/simple_vit.py +0 -0
  28. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/predictors/__init__.py +0 -0
  29. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/samplers/__init__.py +0 -0
  30. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/samplers/ddim.py +0 -0
  31. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/samplers/ddpm.py +0 -0
  32. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/samplers/euler.py +0 -0
  33. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/samplers/heun_sampler.py +0 -0
  34. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/samplers/multistep_dpm.py +0 -0
  35. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/samplers/rk4_sampler.py +0 -0
  36. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/__init__.py +0 -0
  37. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/common.py +0 -0
  38. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/continuous.py +0 -0
  39. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/cosine.py +0 -0
  40. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/discrete.py +0 -0
  41. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/exp.py +0 -0
  42. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/karras.py +0 -0
  43. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/linear.py +0 -0
  44. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/sqrt.py +0 -0
  45. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/trainer/__init__.py +0 -0
  46. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
  47. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/trainer/diffusion_trainer.py +0 -0
  48. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/trainer/simple_trainer.py +0 -0
  49. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/trainer/video_diffusion_trainer.py +0 -0
  50. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/utils.py +0 -0
  51. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff.egg-info/dependency_links.txt +0 -0
  52. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff.egg-info/requires.txt +0 -0
  53. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff.egg-info/top_level.txt +0 -0
  54. {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.1.37.2
3
+ Version: 0.1.37.3
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
File without changes
File without changes
@@ -2,20 +2,15 @@ from flax import linen as nn
2
2
  import jax
3
3
  import jax.numpy as jnp
4
4
  import tqdm
5
- from typing import Union
5
+ from typing import Union, Type
6
6
  from ..schedulers import NoiseScheduler
7
7
  from ..utils import RandomMarkovState, MarkovState, clip_images
8
8
  from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
9
9
 
10
10
  class DiffusionSampler():
11
- model:nn.Module
12
- noise_schedule:NoiseScheduler
13
- params:dict
14
- model_output_transform:DiffusionPredictionTransform
15
-
16
11
  def __init__(self, model:nn.Module, params:dict,
17
12
  noise_schedule:NoiseScheduler,
18
- model_output_transform:DiffusionPredictionTransform=EpsilonPredictionTransform(),
13
+ model_output_transform:DiffusionPredictionTransform,
19
14
  guidance_scale:float = 0.0,
20
15
  null_labels_seq:jax.Array=None,
21
16
  autoencoder=None,
@@ -122,9 +117,11 @@ class DiffusionSampler():
122
117
  end_step:int = 0,
123
118
  steps_override=None,
124
119
  priors=None,
125
- rngstate:RandomMarkovState=RandomMarkovState(jax.random.PRNGKey(42)),
120
+ rngstate:RandomMarkovState=None,
126
121
  model_conditioning_inputs:tuple=()
127
122
  ) -> jnp.ndarray:
123
+ if rngstate is None:
124
+ rngstate = RandomMarkovState(jax.random.PRNGKey(42))
128
125
  if priors is None:
129
126
  rngstate, newrngs = rngstate.get_random_key()
130
127
  samples = self.get_initial_samples(num_images, newrngs, start_step)
@@ -169,4 +166,5 @@ class DiffusionSampler():
169
166
  if self.autoencoder is not None:
170
167
  samples = self.autoencoder.decode(samples)
171
168
  samples = clip_images(samples)
172
- return samples
169
+ return samples
170
+
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.1.37.2
3
+ Version: 0.1.37.3
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
@@ -14,6 +14,8 @@ flaxdiff/data/online_loader.py
14
14
  flaxdiff/data/sources/gcs.py
15
15
  flaxdiff/data/sources/tfds.py
16
16
  flaxdiff/metrics/inception.py
17
+ flaxdiff/metrics/psnr.py
18
+ flaxdiff/metrics/ssim.py
17
19
  flaxdiff/metrics/utils.py
18
20
  flaxdiff/models/__init__.py
19
21
  flaxdiff/models/attention.py
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "flaxdiff"
7
- version = "0.1.37.2"
7
+ version = "0.1.37.3"
8
8
  description = "A versatile and easy to understand Diffusion library"
9
9
  readme = "README.md"
10
10
  authors = [
File without changes
File without changes