flaxdiff 0.1.37.1__py3-none-any.whl → 0.1.37.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
File without changes
@@ -1,7 +1,7 @@
1
- from .common import *
2
- from .ddim import *
3
- from .ddpm import *
4
- from .euler import *
5
- from .heun_sampler import *
6
- from .rk4_sampler import *
7
- from .multistep_dpm import *
1
+ from .common import DiffusionSampler
2
+ from .ddim import DDIMSampler
3
+ from .ddpm import DDPMSampler, SimpleDDPMSampler
4
+ from .euler import EulerSampler, SimplifiedEulerSampler, EulerAncestralSampler
5
+ from .heun_sampler import HeunSampler
6
+ from .rk4_sampler import RK4Sampler
7
+ from .multistep_dpm import MultiStepDPM
@@ -2,20 +2,15 @@ from flax import linen as nn
2
2
  import jax
3
3
  import jax.numpy as jnp
4
4
  import tqdm
5
- from typing import Union
5
+ from typing import Union, Type
6
6
  from ..schedulers import NoiseScheduler
7
7
  from ..utils import RandomMarkovState, MarkovState, clip_images
8
8
  from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
9
9
 
10
10
  class DiffusionSampler():
11
- model:nn.Module
12
- noise_schedule:NoiseScheduler
13
- params:dict
14
- model_output_transform:DiffusionPredictionTransform
15
-
16
11
  def __init__(self, model:nn.Module, params:dict,
17
12
  noise_schedule:NoiseScheduler,
18
- model_output_transform:DiffusionPredictionTransform=EpsilonPredictionTransform(),
13
+ model_output_transform:DiffusionPredictionTransform,
19
14
  guidance_scale:float = 0.0,
20
15
  null_labels_seq:jax.Array=None,
21
16
  autoencoder=None,
@@ -122,9 +117,11 @@ class DiffusionSampler():
122
117
  end_step:int = 0,
123
118
  steps_override=None,
124
119
  priors=None,
125
- rngstate:RandomMarkovState=RandomMarkovState(jax.random.PRNGKey(42)),
120
+ rngstate:RandomMarkovState=None,
126
121
  model_conditioning_inputs:tuple=()
127
122
  ) -> jnp.ndarray:
123
+ if rngstate is None:
124
+ rngstate = RandomMarkovState(jax.random.PRNGKey(42))
128
125
  if priors is None:
129
126
  rngstate, newrngs = rngstate.get_random_key()
130
127
  samples = self.get_initial_samples(num_images, newrngs, start_step)
@@ -169,4 +166,5 @@ class DiffusionSampler():
169
166
  if self.autoencoder is not None:
170
167
  samples = self.autoencoder.decode(samples)
171
168
  samples = clip_images(samples)
172
- return samples
169
+ return samples
170
+
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.1.37.1
3
+ Version: 0.1.37.3
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
@@ -7,6 +7,8 @@ flaxdiff/data/online_loader.py,sha256=1Fi_QRixxRzbt602nORINcDeHEccvCrBpagrz4PURY
7
7
  flaxdiff/data/sources/gcs.py,sha256=11ZuQhvMyJRLg21DgVdzO5qEuae7zgzTXGNOskF-cbs,3380
8
8
  flaxdiff/data/sources/tfds.py,sha256=7n-uobG_UvkD5mU_1ovPd9kb6xJrbEKFFXdVEHDunts,2781
9
9
  flaxdiff/metrics/inception.py,sha256=a5kjMCPMT9gB88c_HCKiek-2vsAyoE35K7nDt4h4pVI,31843
10
+ flaxdiff/metrics/psnr.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
+ flaxdiff/metrics/ssim.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
12
  flaxdiff/metrics/utils.py,sha256=YuuOfqvqgIjsceupwNeJ59vQ2TnGeNMIyKdkIqOmoNg,1702
11
13
  flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
12
14
  flaxdiff/models/attention.py,sha256=JvrP7-09MV6IfRLRBhqjPmNUU-lkEMk9TOnJSBKcar8,13289
@@ -19,8 +21,8 @@ flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEK
19
21
  flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
20
22
  flaxdiff/models/autoencoder/simple_autoenc.py,sha256=UXHPgDmwGTnv3Uts6Zj3p9R9nJXnEiEXbllgarwDfXM,805
21
23
  flaxdiff/predictors/__init__.py,sha256=SKkYYRF9Wfgk2zhtZw4vCXOdOeRlrm2Mk6cvuaEvAzc,4403
22
- flaxdiff/samplers/__init__.py,sha256=dsK_a6HwMpxNuDbl5SzmvKJHYC_DGbTxxS3gfQYwcUY,166
23
- flaxdiff/samplers/common.py,sha256=8SabB-lgcbJZtNBfS96cn2xakeOj6-HC0Evhzeh2NlY,8921
24
+ flaxdiff/samplers/__init__.py,sha256=EY9v1pgwEoR64Kiz9K8fAR-4_ir9c03mYeY3hrpUNhE,308
25
+ flaxdiff/samplers/common.py,sha256=7gKNY4mWVnLjtcioGLFD_Vwmxg9zJovUb8EcYWlc_GE,8833
24
26
  flaxdiff/samplers/ddim.py,sha256=hTjDm0SmIj-Tkc80QRATMcN_sKVhHbqZQboRQCAn4mY,569
25
27
  flaxdiff/samplers/ddpm.py,sha256=JgkNSo7fp7Jm-8rCy4eu5m4YIzXTWzxv-iHf3EQ0z5w,2243
26
28
  flaxdiff/samplers/euler.py,sha256=QSkttB4DYnepDGwhWq3EGXYjMAqj4qLOdh7u98HttzY,2791
@@ -41,7 +43,7 @@ flaxdiff/trainer/autoencoder_trainer.py,sha256=hxihkRL9WCIQVGOP-pc1jjjIUaRXDLcNo
41
43
  flaxdiff/trainer/diffusion_trainer.py,sha256=KVeXJ9ZQKcvD-O_hCJnxro0dQRuQe5ZVGGMEL4Lgm9k,12814
42
44
  flaxdiff/trainer/simple_trainer.py,sha256=lmRo8N0bMupIyS3ejPvPtxoskY_3GLC8iyJE6u4TIWc,21990
43
45
  flaxdiff/trainer/video_diffusion_trainer.py,sha256=gMkKpnKNTo8QhTx5ptEEkc7W5-7rzXIr9queU53hXyQ,2197
44
- flaxdiff-0.1.37.1.dist-info/METADATA,sha256=ThP8nRYIdfvEE2ytSgTbulqbCDBLaWk22t4u2Mk-7-8,23985
45
- flaxdiff-0.1.37.1.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
46
- flaxdiff-0.1.37.1.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
47
- flaxdiff-0.1.37.1.dist-info/RECORD,,
46
+ flaxdiff-0.1.37.3.dist-info/METADATA,sha256=7U7SINGO_ZzsUeuTPi2LTMqxrj93Cvglyh1Q7D39zRM,23985
47
+ flaxdiff-0.1.37.3.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
48
+ flaxdiff-0.1.37.3.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
49
+ flaxdiff-0.1.37.3.dist-info/RECORD,,