flaxdiff 0.1.36.5__tar.gz → 0.1.37.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/PKG-INFO +1 -1
- flaxdiff-0.1.37.1/flaxdiff/samplers/__init__.py +7 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/samplers/common.py +20 -13
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/samplers/ddim.py +1 -1
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/samplers/ddpm.py +2 -2
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/samplers/euler.py +3 -3
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/samplers/heun_sampler.py +2 -2
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/samplers/multistep_dpm.py +1 -1
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/samplers/rk4_sampler.py +7 -7
- flaxdiff-0.1.37.1/flaxdiff/schedulers/__init__.py +6 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/trainer/diffusion_trainer.py +16 -14
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff.egg-info/PKG-INFO +1 -1
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/pyproject.toml +1 -1
- flaxdiff-0.1.36.5/flaxdiff/samplers/__init__.py +0 -7
- flaxdiff-0.1.36.5/flaxdiff/schedulers/__init__.py +0 -6
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/README.md +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/__init__.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/data/__init__.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/data/dataset_map.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/data/datasets.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/data/online_loader.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/data/sources/gcs.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/data/sources/tfds.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/metrics/inception.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/metrics/utils.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/models/__init__.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/models/attention.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/models/autoencoder/__init__.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/models/autoencoder/diffusers.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/models/common.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/models/favor_fastattn.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/models/simple_unet.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/models/simple_vit.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/predictors/__init__.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/schedulers/common.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/schedulers/continuous.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/schedulers/cosine.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/schedulers/discrete.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/schedulers/exp.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/schedulers/karras.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/schedulers/linear.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/schedulers/sqrt.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/trainer/__init__.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/trainer/simple_trainer.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/trainer/video_diffusion_trainer.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/utils.py +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff.egg-info/SOURCES.txt +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff.egg-info/requires.txt +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff.egg-info/top_level.txt +0 -0
- {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/setup.cfg +0 -0
flaxdiff/samplers/common.py:

```diff
@@ -37,7 +37,7 @@ class DiffusionSampler():
             # Classifier free guidance
             assert null_labels_seq is not None, "Null labels sequence is required for classifier-free guidance"
             print("Using classifier-free guidance")
-            def sample_model(x_t, t, *additional_inputs):
+            def sample_model(params, x_t, t, *additional_inputs):
                 # Concatenate unconditional and conditional inputs
                 x_t_cat = jnp.concatenate([x_t] * 2, axis=0)
                 t_cat = jnp.concatenate([t] * 2, axis=0)
@@ -46,7 +46,7 @@ class DiffusionSampler():
 
                 text_labels_seq, = additional_inputs
                 text_labels_seq = jnp.concatenate([text_labels_seq, jnp.broadcast_to(null_labels_seq, text_labels_seq.shape)], axis=0)
-                model_output = self.model.apply(
+                model_output = self.model.apply(params, *self.noise_schedule.transform_inputs(x_t_cat * c_in_cat, t_cat), text_labels_seq)
                 # Split model output into unconditional and conditional parts
                 model_output_cond, model_output_uncond = jnp.split(model_output, 2, axis=0)
                 model_output = model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)
@@ -55,10 +55,10 @@ class DiffusionSampler():
                 return x_0, eps, model_output
         else:
             # Unconditional sampling
-            def sample_model(x_t, t, *additional_inputs):
+            def sample_model(params, x_t, t, *additional_inputs):
                 rates = self.noise_schedule.get_rates(t)
                 c_in = self.model_output_transform.get_input_scale(rates)
-                model_output = self.model.apply(
+                model_output = self.model.apply(params, *self.noise_schedule.transform_inputs(x_t * c_in, t), *additional_inputs)
                 x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
                 return x_0, eps, model_output
 
```
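The guidance branch above runs the conditional and null-conditioned inputs through a single batched forward pass, splits the output, and blends the two halves. A minimal self-contained sketch of that blend, using toy arrays rather than flaxdiff's API:

```python
import jax.numpy as jnp

# Toy stand-ins for one batched forward pass: first half conditional,
# second half null-conditioned (the same concatenate/split trick as above).
model_output = jnp.stack([jnp.array([1.0, 2.0]),    # conditional prediction
                          jnp.array([0.5, 1.0])])   # unconditional prediction

model_output_cond, model_output_uncond = jnp.split(model_output, 2, axis=0)
guidance_scale = 3.0
# Move the prediction away from the unconditional output, past the conditional one.
guided = model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)
print(guided)  # [[2. 4.]]
```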
flaxdiff/samplers/common.py (continued):

```diff
@@ -70,22 +70,23 @@ class DiffusionSampler():
         self.sample_model = sample_model
 
     # Used to sample from the diffusion model
-    def sample_step(self, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
+    def sample_step(self, sample_model_fn, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
         # First clip the noisy images
         step_ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32)
         current_step = step_ones * current_step
         next_step = step_ones * next_step
-        pred_images, pred_noise, _ =
+        pred_images, pred_noise, _ = sample_model_fn(current_samples, current_step, *model_conditioning_inputs)
         # plotImages(pred_images)
         # pred_images = clip_images(pred_images)
         new_samples, state = self.take_next_step(current_samples=current_samples, reconstructed_samples=pred_images,
-
-
+                    pred_noise=pred_noise, current_step=current_step, next_step=next_step, state=state,
+                    model_conditioning_inputs=model_conditioning_inputs,
+                    sample_model_fn=sample_model_fn,
                     )
         return new_samples, state
 
     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1,) -> tuple[jnp.ndarray, RandomMarkovState]:
         # estimate the q(x_{t-1} | x_t, x_0).
         # pred_images is x_0, noisy_images is x_t, steps is t
         return NotImplementedError
```
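The thread running through all of these hunks: `sample_model` now takes `params` explicitly, and the step functions take a pre-bound `sample_model_fn`, so a single sampler instance can be driven with whichever weights the caller chooses at call time (for example EMA weights during validation). A sketch of the binding pattern; `make_sample_model_fn` and `ema_params` are hypothetical names, the signatures follow the hunks:

```python
# Sketch of the explicit-parameter binding introduced in this release.
# `sampler` is any DiffusionSampler subclass; `ema_params` is a hypothetical
# parameter pytree chosen by the caller.
def make_sample_model_fn(sampler, params):
    def sample_model_fn(x_t, t, *additional_inputs):
        # Bind the chosen weights once; the step functions stay oblivious.
        return sampler.sample_model(params, x_t, t, *additional_inputs)
    return sample_model_fn

# The same sampler object can now step under different weights:
# fn = make_sample_model_fn(sampler, ema_params)
# samples, state = sampler.sample_step(fn, samples, current_step,
#                                      model_conditioning_inputs=(),
#                                      next_step=next_step, state=state)
```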
flaxdiff/samplers/common.py (continued):

```diff
@@ -114,6 +115,7 @@ class DiffusionSampler():
         return jax.random.normal(rngs, (num_images, image_size, image_size, image_channels)) * variance
 
     def generate_images(self,
+                        params:dict=None,
                         num_images=16,
                         diffusion_steps=1000,
                         start_step:int = None,
@@ -131,10 +133,15 @@ class DiffusionSampler():
             if self.autoencoder is not None:
                 priors = self.autoencoder.encode(priors)
             samples = priors
+
+        params = params if params is not None else self.params
+
+        def sample_model_fn(x_t, t, *additional_inputs):
+            return self.sample_model(params, x_t, t, *additional_inputs)
 
         # @jax.jit
-        def sample_step(state:RandomMarkovState, samples, current_step, next_step):
-            samples, state = self.sample_step(current_samples=samples,
+        def sample_step(sample_model_fn, state:RandomMarkovState, samples, current_step, next_step):
+            samples, state = self.sample_step(sample_model_fn=sample_model_fn, current_samples=samples,
                                               current_step=current_step,
                                               model_conditioning_inputs=model_conditioning_inputs,
                                               state=state, next_step=next_step)
@@ -154,11 +161,11 @@ class DiffusionSampler():
             next_step = self.scale_steps(steps[i+1] if i+1 < len(steps) else 0)
             if i != len(steps) - 1:
                 # print("normal step")
-                samples, rngstate = sample_step(rngstate, samples, current_step, next_step)
+                samples, rngstate = sample_step(sample_model_fn, rngstate, samples, current_step, next_step)
             else:
                 # print("last step")
                 step_ones = jnp.ones((num_images, ), dtype=jnp.int32)
-                samples, _, _ =
+                samples, _, _ = sample_model_fn(samples, current_step * step_ones, *model_conditioning_inputs)
             if self.autoencoder is not None:
                 samples = self.autoencoder.decode(samples)
             samples = clip_images(samples)
```
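With the new `params` keyword, weights are selected per call instead of being fixed at sampler construction, and `params=None` falls back to `self.params`, preserving the old behaviour. A hedged usage sketch; only the keyword names shown in the hunks are taken from the source, while `sampler` and `ema_params` are assumed to exist:

```python
# Hypothetical call site; keyword names follow the generate_images hunk above.
samples = sampler.generate_images(
    params=ema_params,       # new in 0.1.37.x: weights chosen per call
    num_images=16,
    diffusion_steps=200,
    start_step=1000,
)
```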
flaxdiff/samplers/ddim.py:

```diff
@@ -4,7 +4,7 @@ from ..utils import MarkovState, RandomMarkovState
 
 class DDIMSampler(DiffusionSampler):
     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         next_signal_rate, next_noise_rate = self.noise_schedule.get_rates(next_step)
         return reconstructed_samples * next_signal_rate + pred_noise * next_noise_rate, state
 
```
flaxdiff/samplers/ddpm.py:

```diff
@@ -4,7 +4,7 @@ from .common import DiffusionSampler
 from ..utils import MarkovState, RandomMarkovState
 class DDPMSampler(DiffusionSampler):
     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         mean = self.noise_schedule.get_posterior_mean(reconstructed_samples, current_samples, current_step)
         variance = self.noise_schedule.get_posterior_variance(steps=current_step)
 
@@ -19,7 +19,7 @@ class DDPMSampler(DiffusionSampler):
 
 class SimpleDDPMSampler(DiffusionSampler):
     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         state, rng = state.get_random_key()
         noise = jax.random.normal(rng, reconstructed_samples.shape, dtype=jnp.float32)
 
```
flaxdiff/samplers/euler.py:

```diff
@@ -6,7 +6,7 @@ from ..utils import RandomMarkovState
 class EulerSampler(DiffusionSampler):
     # Basically a DDIM Sampler but parameterized as an ODE
     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
         next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
 
@@ -22,7 +22,7 @@ class SimplifiedEulerSampler(DiffusionSampler):
     This is for networks with forward diffusion of the form x_{t+1} = x_t + sigma_t * epsilon_t
     """
     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         _, current_sigma = self.noise_schedule.get_rates(current_step)
         _, next_sigma = self.noise_schedule.get_rates(next_step)
 
@@ -37,7 +37,7 @@ class EulerAncestralSampler(DiffusionSampler):
     Similar to EulerSampler but with ancestral sampling
     """
     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
         next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
 
```
flaxdiff/samplers/heun_sampler.py:

```diff
@@ -5,7 +5,7 @@ from ..utils import RandomMarkovState
 
 class HeunSampler(DiffusionSampler):
     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         # Get the noise and signal rates for the current and next steps
         current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
         next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
@@ -17,7 +17,7 @@ class HeunSampler(DiffusionSampler):
         next_samples_0 = current_samples + dx_0 * dt
 
         # Recompute x_0 and eps at the first estimate to refine the derivative
-        estimated_x_0, _, _ =
+        estimated_x_0, _, _ = sample_model_fn(next_samples_0, next_step, *model_conditioning_inputs)
 
         # Estimate the refined derivative using the midpoint (Heun's method)
         dx_1 = (next_samples_0 - x_0_coeff * estimated_x_0) / next_sigma
```
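The refinement above is Heun's method: take a trial Euler step, re-evaluate the derivative at the trial point, and combine the two slopes. A scalar sketch of that structure for a generic ODE (not the sampler's exact update):

```python
# Generic Heun (improved Euler) step for dy/dt = f(t, y); a sketch of the
# two-evaluation structure used by HeunSampler above.
def heun_step(f, t, y, dt):
    k1 = f(t, y)                     # slope at the current point
    y_euler = y + k1 * dt            # trial Euler step to the next point
    k2 = f(t + dt, y_euler)          # slope re-evaluated at the trial point
    return y + 0.5 * (k1 + k2) * dt  # average the two slopes

# Example: dy/dt = -y from y(0) = 1, one step of dt = 0.1
print(heun_step(lambda t, y: -y, 0.0, 1.0, 0.1))  # 0.905, vs exact e^{-0.1} ≈ 0.9048
```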
flaxdiff/samplers/multistep_dpm.py:

```diff
@@ -9,7 +9,7 @@ class MultiStepDPM(DiffusionSampler):
         self.history = []
 
     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         # Get the noise and signal rates for the current and next steps
         current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
         next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
```
flaxdiff/samplers/rk4_sampler.py:

```diff
@@ -9,14 +9,14 @@ class RK4Sampler(DiffusionSampler):
         super().__init__(*args, **kwargs)
         assert issubclass(type(self.noise_schedule), GeneralizedNoiseScheduler), "Noise schedule must be a GeneralizedNoiseScheduler"
         @jax.jit
-        def get_derivative(x_t, sigma, state:RandomMarkovState, model_conditioning_inputs) -> tuple[jnp.ndarray, RandomMarkovState]:
+        def get_derivative(sample_model_fn, x_t, sigma, state:RandomMarkovState, model_conditioning_inputs) -> tuple[jnp.ndarray, RandomMarkovState]:
             t = self.noise_schedule.get_timesteps(sigma)
-            x_0, eps, _ =
+            x_0, eps, _ = sample_model_fn(x_t, t, *model_conditioning_inputs)
             return eps, state
 
         self.get_derivative = get_derivative
 
-    def sample_step(self, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
+    def sample_step(self, sample_model_fn, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
         step_ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32)
         current_step = step_ones * current_step
         next_step = step_ones * next_step
@@ -25,10 +25,10 @@ class RK4Sampler(DiffusionSampler):
 
         dt = next_sigma - current_sigma
 
-        k1, state = self.get_derivative(current_samples, current_sigma, state, model_conditioning_inputs)
-        k2, state = self.get_derivative(current_samples + 0.5 * k1 * dt, current_sigma + 0.5 * dt, state, model_conditioning_inputs)
-        k3, state = self.get_derivative(current_samples + 0.5 * k2 * dt, current_sigma + 0.5 * dt, state, model_conditioning_inputs)
-        k4, state = self.get_derivative(current_samples + k3 * dt, current_sigma + dt, state, model_conditioning_inputs)
+        k1, state = self.get_derivative(sample_model_fn, current_samples, current_sigma, state, model_conditioning_inputs)
+        k2, state = self.get_derivative(sample_model_fn, current_samples + 0.5 * k1 * dt, current_sigma + 0.5 * dt, state, model_conditioning_inputs)
+        k3, state = self.get_derivative(sample_model_fn, current_samples + 0.5 * k2 * dt, current_sigma + 0.5 * dt, state, model_conditioning_inputs)
+        k4, state = self.get_derivative(sample_model_fn, current_samples + k3 * dt, current_sigma + dt, state, model_conditioning_inputs)
 
         next_samples = current_samples + (((k1 + 2 * k2 + 2 * k3 + k4) * dt) / 6)
         return next_samples, state
```
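The `k1…k4` combination is the classical fourth-order Runge–Kutta rule, here integrating in σ rather than t. For reference, the same weighting on a scalar ODE:

```python
# Classical RK4 step for dy/dt = f(t, y); mirrors the k1..k4 weighting above.
def rk4_step(f, t, y, dt):
    k1 = f(t, y)
    k2 = f(t + 0.5 * dt, y + 0.5 * k1 * dt)
    k3 = f(t + 0.5 * dt, y + 0.5 * k2 * dt)
    k4 = f(t + dt, y + k3 * dt)
    return y + (k1 + 2 * k2 + 2 * k3 + k4) * dt / 6

# Example: dy/dt = -y, y(0) = 1, one step of dt = 0.1
print(rk4_step(lambda t, y: -y, 0.0, 1.0, 0.1))  # ≈ 0.9048374, matching e^{-0.1}
```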
flaxdiff/trainer/diffusion_trainer.py:

```diff
@@ -235,19 +235,19 @@ class DiffusionTrainer(SimpleTrainer):
         null_labels_full = null_labels_full.astype(jnp.float16)
         # null_labels_seq = jnp.array(null_labels_full[0], dtype=jnp.float16)
 
-
-
-
-
-
-
-
-
-
-
-            return sampler
+        sampler = sampler_class(
+            model=model,
+            params=None,
+            noise_schedule=self.noise_schedule if sampling_noise_schedule is None else sampling_noise_schedule,
+            model_output_transform=self.model_output_transform,
+            image_size=self.input_shapes['x'][0],
+            null_labels_seq=null_labels_full,
+            autoencoder=autoencoder,
+            guidance_scale=3.0,
+        )
 
         def generate_samples(
+            val_state: TrainState,
             batch,
             sampler: DiffusionSampler,
             diffusion_steps: int,
```
flaxdiff/trainer/diffusion_trainer.py (continued):

```diff
@@ -255,6 +255,7 @@ class DiffusionTrainer(SimpleTrainer):
             labels_seq = encoder.encode_from_tokens(batch)
             labels_seq = jnp.array(labels_seq, dtype=jnp.float16)
             samples = sampler.generate_images(
+                params=val_state.ema_params,
                 num_images=len(labels_seq),
                 diffusion_steps=diffusion_steps,
                 start_step=1000,
@@ -264,7 +265,7 @@ class DiffusionTrainer(SimpleTrainer):
             )
             return samples
 
-        return
+        return sampler, generate_samples
 
     def validation_loop(
         self,
@@ -275,14 +276,15 @@ class DiffusionTrainer(SimpleTrainer):
         current_step,
         diffusion_steps=200,
     ):
-
+        sampler, generate_samples = val_step_fn
 
-        sampler = generate_sampler(val_state)
+        # sampler = generate_sampler(val_state)
 
         val_ds = iter(val_ds()) if val_ds else None
         # Evaluation step
         try:
             samples = generate_samples(
+                val_state,
                 next(val_ds),
                 sampler,
                 diffusion_steps,
```
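Taken together, the trainer changes build the sampler once with `params=None`, return it alongside the generation function, and pass the live `TrainState` into every validation call so that `generate_images` can read `val_state.ema_params`. A hedged sketch of the resulting flow; `build_validation_step` is a hypothetical name for the builder, while the tuple unpacking and call arguments mirror the hunks:

```python
# Hypothetical wiring; argument order follows the validation_loop hunk above.
sampler, generate_samples = build_validation_step()

samples = generate_samples(
    val_state,          # new: the TrainState travels with every call,
    next(val_ds),       # so sampling can use val_state.ema_params
    sampler,
    diffusion_steps,
)
```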
flaxdiff-0.1.36.5/flaxdiff/samplers/__init__.py (removed; 0.1.37.1 adds a 7-line replacement, listed above as +7 -0):

```diff
@@ -1,7 +0,0 @@
-from .common import DiffusionSampler
-from .ddim import DDIMSampler
-from .ddpm import DDPMSampler, SimpleDDPMSampler
-from .euler import EulerSampler, SimplifiedEulerSampler
-from .heun_sampler import HeunSampler
-from .rk4_sampler import RK4Sampler
-from .multistep_dpm import MultiStepDPM
```
flaxdiff-0.1.36.5/flaxdiff/schedulers/__init__.py (removed; 0.1.37.1 adds a 6-line replacement, listed above as +6 -0):

```diff
@@ -1,6 +0,0 @@
-from .discrete import DiscreteNoiseScheduler
-from .common import NoiseScheduler, GeneralizedNoiseScheduler
-from .cosine import CosineNoiseScheduler, ContinuousNoiseScheduler, CosineGeneralNoiseScheduler
-from .linear import LinearNoiseSchedule
-from .sqrt import SqrtContinuousNoiseScheduler
-from .karras import KarrasVENoiseScheduler, SimpleExpNoiseScheduler, EDMNoiseScheduler
```