flaxdiff 0.1.37.2__tar.gz → 0.1.37.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/PKG-INFO +1 -1
- flaxdiff-0.1.37.3/flaxdiff/metrics/psnr.py +0 -0
- flaxdiff-0.1.37.3/flaxdiff/metrics/ssim.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/samplers/common.py +7 -9
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff.egg-info/PKG-INFO +1 -1
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff.egg-info/SOURCES.txt +2 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/pyproject.toml +1 -1
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/README.md +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/__init__.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/data/__init__.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/data/dataset_map.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/data/datasets.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/data/online_loader.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/data/sources/gcs.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/data/sources/tfds.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/metrics/inception.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/metrics/utils.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/models/__init__.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/models/attention.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/models/autoencoder/__init__.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/models/autoencoder/diffusers.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/models/common.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/models/favor_fastattn.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/models/simple_unet.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/models/simple_vit.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/predictors/__init__.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/samplers/__init__.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/samplers/ddim.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/samplers/ddpm.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/samplers/euler.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/samplers/heun_sampler.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/samplers/multistep_dpm.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/samplers/rk4_sampler.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/__init__.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/common.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/continuous.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/cosine.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/discrete.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/exp.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/karras.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/linear.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/schedulers/sqrt.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/trainer/__init__.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/trainer/diffusion_trainer.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/trainer/simple_trainer.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/trainer/video_diffusion_trainer.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff/utils.py +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff.egg-info/requires.txt +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/flaxdiff.egg-info/top_level.txt +0 -0
- {flaxdiff-0.1.37.2 → flaxdiff-0.1.37.3}/setup.cfg +0 -0
File without changes
|
File without changes
|
@@ -2,20 +2,15 @@ from flax import linen as nn
|
|
2
2
|
import jax
|
3
3
|
import jax.numpy as jnp
|
4
4
|
import tqdm
|
5
|
-
from typing import Union
|
5
|
+
from typing import Union, Type
|
6
6
|
from ..schedulers import NoiseScheduler
|
7
7
|
from ..utils import RandomMarkovState, MarkovState, clip_images
|
8
8
|
from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
|
9
9
|
|
10
10
|
class DiffusionSampler():
|
11
|
-
model:nn.Module
|
12
|
-
noise_schedule:NoiseScheduler
|
13
|
-
params:dict
|
14
|
-
model_output_transform:DiffusionPredictionTransform
|
15
|
-
|
16
11
|
def __init__(self, model:nn.Module, params:dict,
|
17
12
|
noise_schedule:NoiseScheduler,
|
18
|
-
model_output_transform:DiffusionPredictionTransform
|
13
|
+
model_output_transform:DiffusionPredictionTransform,
|
19
14
|
guidance_scale:float = 0.0,
|
20
15
|
null_labels_seq:jax.Array=None,
|
21
16
|
autoencoder=None,
|
@@ -122,9 +117,11 @@ class DiffusionSampler():
|
|
122
117
|
end_step:int = 0,
|
123
118
|
steps_override=None,
|
124
119
|
priors=None,
|
125
|
-
rngstate:RandomMarkovState=
|
120
|
+
rngstate:RandomMarkovState=None,
|
126
121
|
model_conditioning_inputs:tuple=()
|
127
122
|
) -> jnp.ndarray:
|
123
|
+
if rngstate is None:
|
124
|
+
rngstate = RandomMarkovState(jax.random.PRNGKey(42))
|
128
125
|
if priors is None:
|
129
126
|
rngstate, newrngs = rngstate.get_random_key()
|
130
127
|
samples = self.get_initial_samples(num_images, newrngs, start_step)
|
@@ -169,4 +166,5 @@ class DiffusionSampler():
|
|
169
166
|
if self.autoencoder is not None:
|
170
167
|
samples = self.autoencoder.decode(samples)
|
171
168
|
samples = clip_images(samples)
|
172
|
-
return samples
|
169
|
+
return samples
|
170
|
+
|
@@ -14,6 +14,8 @@ flaxdiff/data/online_loader.py
|
|
14
14
|
flaxdiff/data/sources/gcs.py
|
15
15
|
flaxdiff/data/sources/tfds.py
|
16
16
|
flaxdiff/metrics/inception.py
|
17
|
+
flaxdiff/metrics/psnr.py
|
18
|
+
flaxdiff/metrics/ssim.py
|
17
19
|
flaxdiff/metrics/utils.py
|
18
20
|
flaxdiff/models/__init__.py
|
19
21
|
flaxdiff/models/attention.py
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|