flaxdiff 0.1.37.1__tar.gz → 0.1.37.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/PKG-INFO +1 -1
  2. flaxdiff-0.1.37.3/flaxdiff/metrics/psnr.py +0 -0
  3. flaxdiff-0.1.37.3/flaxdiff/metrics/ssim.py +0 -0
  4. flaxdiff-0.1.37.3/flaxdiff/samplers/__init__.py +7 -0
  5. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/samplers/common.py +7 -9
  6. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff.egg-info/PKG-INFO +1 -1
  7. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff.egg-info/SOURCES.txt +2 -0
  8. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/pyproject.toml +1 -1
  9. flaxdiff-0.1.37.1/flaxdiff/samplers/__init__.py +0 -7
  10. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/README.md +0 -0
  11. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/__init__.py +0 -0
  12. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/data/__init__.py +0 -0
  13. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/data/dataset_map.py +0 -0
  14. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/data/datasets.py +0 -0
  15. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/data/online_loader.py +0 -0
  16. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/data/sources/gcs.py +0 -0
  17. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/data/sources/tfds.py +0 -0
  18. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/metrics/inception.py +0 -0
  19. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/metrics/utils.py +0 -0
  20. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/models/__init__.py +0 -0
  21. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/models/attention.py +0 -0
  22. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/models/autoencoder/__init__.py +0 -0
  23. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  24. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/models/autoencoder/diffusers.py +0 -0
  25. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  26. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/models/common.py +0 -0
  27. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/models/favor_fastattn.py +0 -0
  28. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/models/simple_unet.py +0 -0
  29. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/models/simple_vit.py +0 -0
  30. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/predictors/__init__.py +0 -0
  31. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/samplers/ddim.py +0 -0
  32. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/samplers/ddpm.py +0 -0
  33. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/samplers/euler.py +0 -0
  34. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/samplers/heun_sampler.py +0 -0
  35. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/samplers/multistep_dpm.py +0 -0
  36. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/samplers/rk4_sampler.py +0 -0
  37. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/__init__.py +0 -0
  38. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/common.py +0 -0
  39. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/continuous.py +0 -0
  40. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/cosine.py +0 -0
  41. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/discrete.py +0 -0
  42. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/exp.py +0 -0
  43. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/karras.py +0 -0
  44. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/linear.py +0 -0
  45. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/sqrt.py +0 -0
  46. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/trainer/__init__.py +0 -0
  47. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
  48. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/trainer/diffusion_trainer.py +0 -0
  49. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/trainer/simple_trainer.py +0 -0
  50. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/trainer/video_diffusion_trainer.py +0 -0
  51. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff/utils.py +0 -0
  52. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff.egg-info/dependency_links.txt +0 -0
  53. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff.egg-info/requires.txt +0 -0
  54. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/flaxdiff.egg-info/top_level.txt +0 -0
  55. {flaxdiff-0.1.37.1 → flaxdiff-0.1.37.3}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.1.37.1
3
+ Version: 0.1.37.3
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
File without changes
File without changes
@@ -0,0 +1,7 @@
1
+ from .common import DiffusionSampler
2
+ from .ddim import DDIMSampler
3
+ from .ddpm import DDPMSampler, SimpleDDPMSampler
4
+ from .euler import EulerSampler, SimplifiedEulerSampler, EulerAncestralSampler
5
+ from .heun_sampler import HeunSampler
6
+ from .rk4_sampler import RK4Sampler
7
+ from .multistep_dpm import MultiStepDPM
@@ -2,20 +2,15 @@ from flax import linen as nn
2
2
  import jax
3
3
  import jax.numpy as jnp
4
4
  import tqdm
5
- from typing import Union
5
+ from typing import Union, Type
6
6
  from ..schedulers import NoiseScheduler
7
7
  from ..utils import RandomMarkovState, MarkovState, clip_images
8
8
  from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
9
9
 
10
10
  class DiffusionSampler():
11
- model:nn.Module
12
- noise_schedule:NoiseScheduler
13
- params:dict
14
- model_output_transform:DiffusionPredictionTransform
15
-
16
11
  def __init__(self, model:nn.Module, params:dict,
17
12
  noise_schedule:NoiseScheduler,
18
- model_output_transform:DiffusionPredictionTransform=EpsilonPredictionTransform(),
13
+ model_output_transform:DiffusionPredictionTransform,
19
14
  guidance_scale:float = 0.0,
20
15
  null_labels_seq:jax.Array=None,
21
16
  autoencoder=None,
@@ -122,9 +117,11 @@ class DiffusionSampler():
122
117
  end_step:int = 0,
123
118
  steps_override=None,
124
119
  priors=None,
125
- rngstate:RandomMarkovState=RandomMarkovState(jax.random.PRNGKey(42)),
120
+ rngstate:RandomMarkovState=None,
126
121
  model_conditioning_inputs:tuple=()
127
122
  ) -> jnp.ndarray:
123
+ if rngstate is None:
124
+ rngstate = RandomMarkovState(jax.random.PRNGKey(42))
128
125
  if priors is None:
129
126
  rngstate, newrngs = rngstate.get_random_key()
130
127
  samples = self.get_initial_samples(num_images, newrngs, start_step)
@@ -169,4 +166,5 @@ class DiffusionSampler():
169
166
  if self.autoencoder is not None:
170
167
  samples = self.autoencoder.decode(samples)
171
168
  samples = clip_images(samples)
172
- return samples
169
+ return samples
170
+
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.1.37.1
3
+ Version: 0.1.37.3
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
@@ -14,6 +14,8 @@ flaxdiff/data/online_loader.py
14
14
  flaxdiff/data/sources/gcs.py
15
15
  flaxdiff/data/sources/tfds.py
16
16
  flaxdiff/metrics/inception.py
17
+ flaxdiff/metrics/psnr.py
18
+ flaxdiff/metrics/ssim.py
17
19
  flaxdiff/metrics/utils.py
18
20
  flaxdiff/models/__init__.py
19
21
  flaxdiff/models/attention.py
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "flaxdiff"
7
- version = "0.1.37.1"
7
+ version = "0.1.37.3"
8
8
  description = "A versatile and easy to understand Diffusion library"
9
9
  readme = "README.md"
10
10
  authors = [
@@ -1,7 +0,0 @@
1
- from .common import *
2
- from .ddim import *
3
- from .ddpm import *
4
- from .euler import *
5
- from .heun_sampler import *
6
- from .rk4_sampler import *
7
- from .multistep_dpm import *
File without changes
File without changes