flaxdiff 0.1.35.5__py3-none-any.whl → 0.1.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/dataset_map.py +71 -0
- flaxdiff/data/datasets.py +169 -0
- flaxdiff/data/online_loader.py +69 -42
- flaxdiff/models/attention.py +1 -0
- flaxdiff/models/simple_unet.py +11 -11
- flaxdiff/models/simple_vit.py +1 -1
- flaxdiff/samplers/common.py +72 -20
- flaxdiff/samplers/ddim.py +5 -5
- flaxdiff/samplers/ddpm.py +5 -11
- flaxdiff/samplers/euler.py +7 -10
- flaxdiff/samplers/heun_sampler.py +3 -4
- flaxdiff/samplers/multistep_dpm.py +2 -3
- flaxdiff/samplers/rk4_sampler.py +9 -9
- flaxdiff/trainer/autoencoder_trainer.py +1 -1
- flaxdiff/trainer/diffusion_trainer.py +124 -32
- flaxdiff/trainer/simple_trainer.py +187 -91
- flaxdiff/trainer/video_diffusion_trainer.py +62 -0
- flaxdiff/utils.py +105 -2
- {flaxdiff-0.1.35.5.dist-info → flaxdiff-0.1.36.dist-info}/METADATA +11 -5
- flaxdiff-0.1.36.dist-info/RECORD +43 -0
- {flaxdiff-0.1.35.5.dist-info → flaxdiff-0.1.36.dist-info}/WHEEL +1 -1
- flaxdiff-0.1.35.5.dist-info/RECORD +0 -40
- {flaxdiff-0.1.35.5.dist-info → flaxdiff-0.1.36.dist-info}/top_level.txt +0 -0
flaxdiff/samplers/ddpm.py
CHANGED
@@ -3,9 +3,8 @@ import jax.numpy as jnp
|
|
3
3
|
from .common import DiffusionSampler
|
4
4
|
from ..utils import MarkovState, RandomMarkovState
|
5
5
|
class DDPMSampler(DiffusionSampler):
|
6
|
-
def take_next_step(self,
|
7
|
-
|
8
|
-
pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
|
6
|
+
def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
|
7
|
+
pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
|
9
8
|
mean = self.noise_schedule.get_posterior_mean(reconstructed_samples, current_samples, current_step)
|
10
9
|
variance = self.noise_schedule.get_posterior_variance(steps=current_step)
|
11
10
|
|
@@ -19,9 +18,8 @@ class DDPMSampler(DiffusionSampler):
|
|
19
18
|
return super().generate_images(num_images=num_images, diffusion_steps=diffusion_steps, start_step=start_step, *args, **kwargs)
|
20
19
|
|
21
20
|
class SimpleDDPMSampler(DiffusionSampler):
|
22
|
-
def take_next_step(self,
|
23
|
-
|
24
|
-
pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
|
21
|
+
def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
|
22
|
+
pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
|
25
23
|
state, rng = state.get_random_key()
|
26
24
|
noise = jax.random.normal(rng, reconstructed_samples.shape, dtype=jnp.float32)
|
27
25
|
|
@@ -33,11 +31,7 @@ class SimpleDDPMSampler(DiffusionSampler):
|
|
33
31
|
|
34
32
|
noise_ratio_squared = (next_noise_rate ** 2) / (current_noise_rate ** 2)
|
35
33
|
signal_ratio_squared = (current_signal_rate ** 2) / (next_signal_rate ** 2)
|
36
|
-
|
37
|
-
gamma = jnp.sqrt(noise_ratio_squared * betas)
|
34
|
+
gamma = jnp.sqrt(noise_ratio_squared * (1 - signal_ratio_squared))
|
38
35
|
|
39
36
|
next_samples = next_signal_rate * reconstructed_samples + pred_noise_coeff * pred_noise + noise * gamma
|
40
|
-
# pred_noise_coeff = ((next_noise_rate ** 2) * current_signal_rate) / (current_noise_rate * next_signal_rate)
|
41
|
-
# next_samples = (2 - jnp.sqrt(1 - betas)) * current_samples - betas * (pred_noise / current_noise_rate) + noise * gamma#jnp.sqrt(betas)
|
42
|
-
# next_samples = (1 / (jnp.sqrt(1 - betas) + 1.e-24)) * (current_samples - betas * (pred_noise / current_noise_rate)) + noise * gamma
|
43
37
|
return next_samples, state
|
flaxdiff/samplers/euler.py
CHANGED
@@ -5,9 +5,8 @@ from ..utils import RandomMarkovState
|
|
5
5
|
|
6
6
|
class EulerSampler(DiffusionSampler):
|
7
7
|
# Basically a DDIM Sampler but parameterized as an ODE
|
8
|
-
def take_next_step(self,
|
9
|
-
|
10
|
-
pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
|
8
|
+
def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
|
9
|
+
pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
|
11
10
|
current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
|
12
11
|
next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
|
13
12
|
|
@@ -22,9 +21,8 @@ class SimplifiedEulerSampler(DiffusionSampler):
|
|
22
21
|
"""
|
23
22
|
This is for networks with forward diffusion of the form x_{t+1} = x_t + sigma_t * epsilon_t
|
24
23
|
"""
|
25
|
-
def take_next_step(self,
|
26
|
-
|
27
|
-
pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
|
24
|
+
def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
|
25
|
+
pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
|
28
26
|
_, current_sigma = self.noise_schedule.get_rates(current_step)
|
29
27
|
_, next_sigma = self.noise_schedule.get_rates(next_step)
|
30
28
|
|
@@ -38,9 +36,8 @@ class EulerAncestralSampler(DiffusionSampler):
|
|
38
36
|
"""
|
39
37
|
Similar to EulerSampler but with ancestral sampling
|
40
38
|
"""
|
41
|
-
def take_next_step(self,
|
42
|
-
|
43
|
-
pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
|
39
|
+
def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
|
40
|
+
pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
|
44
41
|
current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
|
45
42
|
next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
|
46
43
|
|
@@ -56,4 +53,4 @@ class EulerAncestralSampler(DiffusionSampler):
|
|
56
53
|
dW = jax.random.normal(subkey, current_samples.shape) * sigma_up
|
57
54
|
|
58
55
|
next_samples = current_samples + dx * dt + dW
|
59
|
-
return next_samples, state
|
56
|
+
return next_samples, state
|
@@ -4,9 +4,8 @@ from .common import DiffusionSampler
|
|
4
4
|
from ..utils import RandomMarkovState
|
5
5
|
|
6
6
|
class HeunSampler(DiffusionSampler):
|
7
|
-
def take_next_step(self,
|
8
|
-
|
9
|
-
pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
|
7
|
+
def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
|
8
|
+
pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
|
10
9
|
# Get the noise and signal rates for the current and next steps
|
11
10
|
current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
|
12
11
|
next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
|
@@ -18,7 +17,7 @@ class HeunSampler(DiffusionSampler):
|
|
18
17
|
next_samples_0 = current_samples + dx_0 * dt
|
19
18
|
|
20
19
|
# Recompute x_0 and eps at the first estimate to refine the derivative
|
21
|
-
estimated_x_0, _, _ = self.sample_model(next_samples_0, next_step)
|
20
|
+
estimated_x_0, _, _ = self.sample_model(next_samples_0, next_step, *model_conditioning_inputs)
|
22
21
|
|
23
22
|
# Estimate the refined derivative using the midpoint (Heun's method)
|
24
23
|
dx_1 = (next_samples_0 - x_0_coeff * estimated_x_0) / next_sigma
|
@@ -8,9 +8,8 @@ class MultiStepDPM(DiffusionSampler):
|
|
8
8
|
super().__init__(*args, **kwargs)
|
9
9
|
self.history = []
|
10
10
|
|
11
|
-
def
|
12
|
-
|
13
|
-
pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
|
11
|
+
def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
|
12
|
+
pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
|
14
13
|
# Get the noise and signal rates for the current and next steps
|
15
14
|
current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
|
16
15
|
next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
|
flaxdiff/samplers/rk4_sampler.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
import jax
|
2
2
|
import jax.numpy as jnp
|
3
3
|
from .common import DiffusionSampler
|
4
|
-
from ..utils import RandomMarkovState
|
4
|
+
from ..utils import RandomMarkovState, MarkovState
|
5
5
|
from ..schedulers import GeneralizedNoiseScheduler
|
6
6
|
|
7
7
|
class RK4Sampler(DiffusionSampler):
|
@@ -9,14 +9,14 @@ class RK4Sampler(DiffusionSampler):
|
|
9
9
|
super().__init__(*args, **kwargs)
|
10
10
|
assert issubclass(type(self.noise_schedule), GeneralizedNoiseScheduler), "Noise schedule must be a GeneralizedNoiseScheduler"
|
11
11
|
@jax.jit
|
12
|
-
def get_derivative(x_t, sigma, state:RandomMarkovState) -> tuple[jnp.ndarray, RandomMarkovState]:
|
12
|
+
def get_derivative(x_t, sigma, state:RandomMarkovState, model_conditioning_inputs) -> tuple[jnp.ndarray, RandomMarkovState]:
|
13
13
|
t = self.noise_schedule.get_timesteps(sigma)
|
14
|
-
x_0, eps, _ = self.sample_model(x_t, t)
|
14
|
+
x_0, eps, _ = self.sample_model(x_t, t, *model_conditioning_inputs)
|
15
15
|
return eps, state
|
16
16
|
|
17
17
|
self.get_derivative = get_derivative
|
18
18
|
|
19
|
-
def sample_step(self, current_samples:jnp.ndarray, current_step, next_step, state:
|
19
|
+
def sample_step(self, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
|
20
20
|
step_ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32)
|
21
21
|
current_step = step_ones * current_step
|
22
22
|
next_step = step_ones * next_step
|
@@ -25,10 +25,10 @@ class RK4Sampler(DiffusionSampler):
|
|
25
25
|
|
26
26
|
dt = next_sigma - current_sigma
|
27
27
|
|
28
|
-
k1, state = self.get_derivative(current_samples, current_sigma, state)
|
29
|
-
k2, state = self.get_derivative(current_samples + 0.5 * k1 * dt, current_sigma + 0.5 * dt, state)
|
30
|
-
k3, state = self.get_derivative(current_samples + 0.5 * k2 * dt, current_sigma + 0.5 * dt, state)
|
31
|
-
k4, state = self.get_derivative(current_samples + k3 * dt, current_sigma + dt, state)
|
28
|
+
k1, state = self.get_derivative(current_samples, current_sigma, state, model_conditioning_inputs)
|
29
|
+
k2, state = self.get_derivative(current_samples + 0.5 * k1 * dt, current_sigma + 0.5 * dt, state, model_conditioning_inputs)
|
30
|
+
k3, state = self.get_derivative(current_samples + 0.5 * k2 * dt, current_sigma + 0.5 * dt, state, model_conditioning_inputs)
|
31
|
+
k4, state = self.get_derivative(current_samples + k3 * dt, current_sigma + dt, state, model_conditioning_inputs)
|
32
32
|
|
33
|
-
next_samples = current_samples + ((k1 + 2 * k2 + 2 * k3 + k4) / 6)
|
33
|
+
next_samples = current_samples + (((k1 + 2 * k2 + 2 * k3 + k4) * dt) / 6)
|
34
34
|
return next_samples, state
|
@@ -14,7 +14,7 @@ from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransfor
|
|
14
14
|
from flaxdiff.utils import RandomMarkovState
|
15
15
|
|
16
16
|
from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
|
17
|
-
|
17
|
+
from .diffusion_trainer import TrainState
|
18
18
|
from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
|
19
19
|
|
20
20
|
class AutoEncoderTrainer(SimpleTrainer):
|
@@ -1,22 +1,27 @@
|
|
1
|
+
import flax
|
1
2
|
from flax import linen as nn
|
2
3
|
import jax
|
3
4
|
from typing import Callable
|
4
5
|
from dataclasses import field
|
5
6
|
import jax.numpy as jnp
|
7
|
+
import traceback
|
6
8
|
import optax
|
9
|
+
import functools
|
7
10
|
from jax.sharding import Mesh, PartitionSpec as P
|
8
11
|
from jax.experimental.shard_map import shard_map
|
9
|
-
from typing import Dict, Callable, Sequence, Any, Union, Tuple
|
12
|
+
from typing import Dict, Callable, Sequence, Any, Union, Tuple, Type
|
10
13
|
|
11
14
|
from ..schedulers import NoiseScheduler
|
12
15
|
from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
|
16
|
+
from ..samplers.common import DiffusionSampler
|
13
17
|
|
14
18
|
from flaxdiff.utils import RandomMarkovState
|
15
19
|
|
16
20
|
from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
|
17
21
|
|
18
22
|
from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
|
19
|
-
from flax.training
|
23
|
+
from flax.training import dynamic_scale as dynamic_scale_lib
|
24
|
+
from flaxdiff.utils import TextEncoder, ConditioningEncoder
|
20
25
|
|
21
26
|
class TrainState(SimpleTrainState):
|
22
27
|
rngs: jax.random.PRNGKey
|
@@ -47,6 +52,7 @@ class DiffusionTrainer(SimpleTrainer):
|
|
47
52
|
name: str = "Diffusion",
|
48
53
|
model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
|
49
54
|
autoencoder: AutoEncoder = None,
|
55
|
+
encoder: ConditioningEncoder = None,
|
50
56
|
**kwargs
|
51
57
|
):
|
52
58
|
super().__init__(
|
@@ -62,6 +68,7 @@ class DiffusionTrainer(SimpleTrainer):
|
|
62
68
|
self.unconditional_prob = unconditional_prob
|
63
69
|
|
64
70
|
self.autoencoder = autoencoder
|
71
|
+
self.encoder = encoder
|
65
72
|
|
66
73
|
def generate_states(
|
67
74
|
self,
|
@@ -84,8 +91,7 @@ class DiffusionTrainer(SimpleTrainer):
|
|
84
91
|
new_state = existing_state
|
85
92
|
|
86
93
|
if param_transforms is not None:
|
87
|
-
|
88
|
-
new_state['ema_params'] = param_transforms(new_state['ema_params'])
|
94
|
+
params = param_transforms(params)
|
89
95
|
|
90
96
|
state = TrainState.create(
|
91
97
|
apply_fn=model.apply,
|
@@ -94,7 +100,7 @@ class DiffusionTrainer(SimpleTrainer):
|
|
94
100
|
tx=optimizer,
|
95
101
|
rngs=rngs,
|
96
102
|
metrics=Metrics.empty(),
|
97
|
-
dynamic_scale = DynamicScale() if use_dynamic_scale else None
|
103
|
+
dynamic_scale = dynamic_scale_lib.DynamicScale() if use_dynamic_scale else None
|
98
104
|
)
|
99
105
|
|
100
106
|
if existing_best_state is not None:
|
@@ -105,7 +111,7 @@ class DiffusionTrainer(SimpleTrainer):
|
|
105
111
|
|
106
112
|
return state, best_state
|
107
113
|
|
108
|
-
def _define_train_step(self, batch_size
|
114
|
+
def _define_train_step(self, batch_size):
|
109
115
|
noise_schedule: NoiseScheduler = self.noise_schedule
|
110
116
|
model = self.model
|
111
117
|
model_output_transform = self.model_output_transform
|
@@ -114,6 +120,11 @@ class DiffusionTrainer(SimpleTrainer):
|
|
114
120
|
|
115
121
|
# Determine the number of unconditional samples
|
116
122
|
num_unconditional = int(batch_size * unconditional_prob)
|
123
|
+
|
124
|
+
null_labels_full = self.encoder([""])
|
125
|
+
null_labels_seq = jnp.array(null_labels_full[0], dtype=jnp.float16)
|
126
|
+
|
127
|
+
conditioning_encoder = self.encoder
|
117
128
|
|
118
129
|
nS, nC = null_labels_seq.shape
|
119
130
|
null_labels_seq = jnp.broadcast_to(
|
@@ -131,6 +142,11 @@ class DiffusionTrainer(SimpleTrainer):
|
|
131
142
|
local_rng_state = RandomMarkovState(subkey)
|
132
143
|
|
133
144
|
images = batch['image']
|
145
|
+
|
146
|
+
# First get the standard deviation of the images
|
147
|
+
# std = jnp.std(images, axis=(1, 2, 3))
|
148
|
+
# is_non_zero = (std > 0)
|
149
|
+
|
134
150
|
images = jnp.array(images, dtype=jnp.float32)
|
135
151
|
# normalize image
|
136
152
|
images = (images - 127.5) / 127.5
|
@@ -140,9 +156,7 @@ class DiffusionTrainer(SimpleTrainer):
|
|
140
156
|
local_rng_state, rngs = local_rng_state.get_random_key()
|
141
157
|
images = autoencoder.encode(images, rngs)
|
142
158
|
|
143
|
-
|
144
|
-
input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
|
145
|
-
label_seq = output.last_hidden_state
|
159
|
+
label_seq = conditioning_encoder.encode_from_tokens(batch)
|
146
160
|
|
147
161
|
# Generate random probabilities to decide how much of this batch will be unconditional
|
148
162
|
|
@@ -163,8 +177,11 @@ class DiffusionTrainer(SimpleTrainer):
|
|
163
177
|
preds = model_output_transform.pred_transform(
|
164
178
|
noisy_images, preds, rates)
|
165
179
|
nloss = loss_fn(preds, expected_output)
|
166
|
-
#
|
180
|
+
# Ignore the loss contribution of images with zero standard deviation
|
167
181
|
nloss *= noise_schedule.get_weights(noise_level)
|
182
|
+
# nloss = jnp.mean(nloss, axis=(1,2,3))
|
183
|
+
# nloss = jnp.where(is_non_zero, nloss, 0)
|
184
|
+
# nloss = jnp.mean(nloss, where=nloss != 0)
|
168
185
|
nloss = jnp.mean(nloss)
|
169
186
|
loss = nloss
|
170
187
|
return loss
|
@@ -185,11 +202,11 @@ class DiffusionTrainer(SimpleTrainer):
|
|
185
202
|
|
186
203
|
new_state = train_state.apply_gradients(grads=grads)
|
187
204
|
|
188
|
-
if train_state.dynamic_scale:
|
205
|
+
if train_state.dynamic_scale is not None:
|
189
206
|
# if is_fin == False the gradients contain Inf/NaNs and optimizer state and
|
190
207
|
# params should be restored (= skip this step).
|
191
208
|
select_fn = functools.partial(jnp.where, is_fin)
|
192
|
-
new_state =
|
209
|
+
new_state = new_state.replace(
|
193
210
|
opt_state=jax.tree_util.tree_map(
|
194
211
|
select_fn, new_state.opt_state, train_state.opt_state
|
195
212
|
),
|
@@ -211,24 +228,99 @@ class DiffusionTrainer(SimpleTrainer):
|
|
211
228
|
|
212
229
|
return train_step
|
213
230
|
|
214
|
-
def
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
231
|
+
def _define_vaidation_step(self, sampler_class: Type[DiffusionSampler]):
|
232
|
+
model = self.model
|
233
|
+
encoder = self.encoder
|
234
|
+
autoencoder = self.autoencoder
|
235
|
+
|
236
|
+
null_labels_full = encoder([""])
|
237
|
+
null_labels_full = null_labels_full.astype(jnp.float16)
|
238
|
+
# null_labels_seq = jnp.array(null_labels_full[0], dtype=jnp.float16)
|
239
|
+
|
240
|
+
def generate_sampler(state: TrainState):
|
241
|
+
sampler = sampler_class(
|
242
|
+
model=model,
|
243
|
+
params=state.ema_params,
|
244
|
+
noise_schedule=self.noise_schedule,
|
245
|
+
model_output_transform=self.model_output_transform,
|
246
|
+
image_size=self.input_shapes['x'][0],
|
247
|
+
null_labels_seq=null_labels_full,
|
248
|
+
autoencoder=autoencoder,
|
249
|
+
)
|
250
|
+
return sampler
|
251
|
+
|
252
|
+
def generate_samples(
|
253
|
+
batch,
|
254
|
+
sampler: DiffusionSampler,
|
255
|
+
diffusion_steps: int,
|
256
|
+
):
|
257
|
+
labels_seq = encoder.encode_from_tokens(batch)
|
258
|
+
labels_seq = jnp.array(labels_seq, dtype=jnp.float16)
|
259
|
+
samples = sampler.generate_images(
|
260
|
+
num_images=len(labels_seq),
|
261
|
+
diffusion_steps=diffusion_steps,
|
262
|
+
start_step=1000,
|
263
|
+
end_step=0,
|
264
|
+
priors=None,
|
265
|
+
model_conditioning_inputs=(labels_seq,),
|
266
|
+
)
|
267
|
+
return samples
|
268
|
+
|
269
|
+
return generate_sampler, generate_samples
|
270
|
+
|
271
|
+
def validation_loop(
|
272
|
+
self,
|
273
|
+
val_state: SimpleTrainState,
|
274
|
+
val_step_fn: Callable,
|
275
|
+
val_ds,
|
276
|
+
val_steps_per_epoch,
|
277
|
+
current_step,
|
278
|
+
diffusion_steps=200,
|
279
|
+
):
|
280
|
+
generate_sampler, generate_samples = val_step_fn
|
281
|
+
|
282
|
+
sampler = generate_sampler(val_state)
|
283
|
+
|
284
|
+
val_ds = iter(val_ds()) if val_ds else None
|
285
|
+
# Evaluation step
|
286
|
+
try:
|
287
|
+
samples = generate_samples(
|
288
|
+
next(val_ds),
|
289
|
+
sampler,
|
290
|
+
diffusion_steps,
|
291
|
+
)
|
292
|
+
|
293
|
+
# Put each sample on wandb
|
294
|
+
if self.wandb:
|
295
|
+
import numpy as np
|
296
|
+
from wandb import Image as wandbImage
|
297
|
+
wandb_images = []
|
298
|
+
for i in range(samples.shape[0]):
|
299
|
+
# convert the sample to numpy
|
300
|
+
sample = np.array(samples[i])
|
301
|
+
# denormalize the image
|
302
|
+
sample = (sample + 1) * 127.5
|
303
|
+
sample = np.clip(sample, 0, 255).astype(np.uint8)
|
304
|
+
# add the image to the list
|
305
|
+
wandb_images.append(sample)
|
306
|
+
# log the images to wandb
|
307
|
+
self.wandb.log({
|
308
|
+
f"sample_{i}": wandbImage(sample, caption=f"Sample {i} at step {current_step}")
|
309
|
+
}, step=current_step)
|
310
|
+
except Exception as e:
|
311
|
+
print("Error logging images to wandb", e)
|
312
|
+
traceback.print_exc()
|
313
|
+
|
314
|
+
def fit(self, data, training_steps_per_epoch, epochs, val_steps_per_epoch=8, sampler_class=None):
|
226
315
|
local_batch_size = data['local_batch_size']
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
316
|
+
validation_step_args = {
|
317
|
+
"sampler_class": sampler_class,
|
318
|
+
}
|
319
|
+
super().fit(
|
320
|
+
data,
|
321
|
+
train_steps_per_epoch=training_steps_per_epoch,
|
322
|
+
epochs=epochs,
|
323
|
+
train_step_args={"batch_size": local_batch_size},
|
324
|
+
val_steps_per_epoch=val_steps_per_epoch,
|
325
|
+
validation_step_args=validation_step_args,
|
326
|
+
)
|