flaxdiff 0.1.36.1__py3-none-any.whl → 0.1.36.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data/__init__.py +1 -0
- data/dataset_map.py +71 -0
- data/datasets.py +169 -0
- data/online_loader.py +363 -0
- data/sources/gcs.py +81 -0
- data/sources/tfds.py +67 -0
- {flaxdiff-0.1.36.1.dist-info → flaxdiff-0.1.36.2.dist-info}/METADATA +1 -1
- flaxdiff-0.1.36.2.dist-info/RECORD +47 -0
- flaxdiff-0.1.36.2.dist-info/top_level.txt +9 -0
- metrics/inception.py +658 -0
- metrics/utils.py +49 -0
- models/__init__.py +1 -0
- models/attention.py +368 -0
- models/autoencoder/__init__.py +2 -0
- models/autoencoder/autoencoder.py +19 -0
- models/autoencoder/diffusers.py +91 -0
- models/autoencoder/simple_autoenc.py +26 -0
- models/common.py +346 -0
- models/favor_fastattn.py +723 -0
- models/simple_unet.py +233 -0
- models/simple_vit.py +180 -0
- predictors/__init__.py +96 -0
- samplers/__init__.py +7 -0
- samplers/common.py +165 -0
- samplers/ddim.py +10 -0
- samplers/ddpm.py +37 -0
- samplers/euler.py +56 -0
- samplers/heun_sampler.py +27 -0
- samplers/multistep_dpm.py +59 -0
- samplers/rk4_sampler.py +34 -0
- schedulers/__init__.py +6 -0
- schedulers/common.py +98 -0
- schedulers/continuous.py +12 -0
- schedulers/cosine.py +40 -0
- schedulers/discrete.py +74 -0
- schedulers/exp.py +13 -0
- schedulers/karras.py +69 -0
- schedulers/linear.py +14 -0
- schedulers/sqrt.py +10 -0
- trainer/__init__.py +2 -0
- trainer/autoencoder_trainer.py +182 -0
- trainer/diffusion_trainer.py +326 -0
- trainer/simple_trainer.py +540 -0
- trainer/video_diffusion_trainer.py +62 -0
- flaxdiff-0.1.36.1.dist-info/RECORD +0 -6
- flaxdiff-0.1.36.1.dist-info/top_level.txt +0 -1
- /flaxdiff/__init__.py → /__init__.py +0 -0
- {flaxdiff-0.1.36.1.dist-info → flaxdiff-0.1.36.2.dist-info}/WHEEL +0 -0
- /flaxdiff/utils.py → /utils.py +0 -0
@@ -0,0 +1,182 @@
|
|
1
|
+
from flax import linen as nn
|
2
|
+
import jax
|
3
|
+
from typing import Callable
|
4
|
+
from dataclasses import field
|
5
|
+
import jax.numpy as jnp
|
6
|
+
import optax
|
7
|
+
from jax.sharding import Mesh, PartitionSpec as P
|
8
|
+
from jax.experimental.shard_map import shard_map
|
9
|
+
from typing import Dict, Callable, Sequence, Any, Union, Tuple
|
10
|
+
|
11
|
+
from ..schedulers import NoiseScheduler
|
12
|
+
from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
|
13
|
+
|
14
|
+
from flaxdiff.utils import RandomMarkovState
|
15
|
+
|
16
|
+
from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
|
17
|
+
from .diffusion_trainer import TrainState
|
18
|
+
from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
|
19
|
+
|
20
|
+
class AutoEncoderTrainer(SimpleTrainer):
    """Trainer for autoencoder models, layered on top of ``SimpleTrainer``.

    NOTE(review): ``_define_train_step`` reads ``self.noise_schedule``,
    ``self.model_output_transform``, ``self.unconditional_prob``,
    ``self.autoencoder`` and ``self.ema_decay``. None of these are assigned
    by this class's ``__init__``; unless ``SimpleTrainer`` (not visible here)
    provides them, building the train step will raise ``AttributeError``.
    The step logic mirrors the diffusion trainer - confirm this is intended.
    """

    def __init__(self,
                 model: nn.Module,
                 input_shape: Tuple[int, int, int],
                 latent_dim: int,
                 spatial_scale: int,
                 optimizer: optax.GradientTransformation,
                 rngs: jax.random.PRNGKey,
                 name: str = "Autoencoder",
                 **kwargs
                 ):
        """Initialise the base trainer and record the latent configuration.

        Args:
            model: Flax module to train.
            input_shape: Shape of a single input image; forwarded to the base
                trainer as ``input_shapes={"image": input_shape}``.
            latent_dim: Stored on the instance as ``self.latent_dim``.
            spatial_scale: Stored on the instance as ``self.spatial_scale``.
            optimizer: Optax gradient transformation forwarded to the base.
            rngs: PRNG key forwarded to the base trainer.
            name: Trainer name forwarded to the base trainer.
            **kwargs: Forwarded unchanged to ``SimpleTrainer.__init__``.
        """
        super().__init__(
            model=model,
            input_shapes={"image": input_shape},
            optimizer=optimizer,
            rngs=rngs,
            name=name,
            **kwargs
        )
        self.latent_dim = latent_dim
        self.spatial_scale = spatial_scale

    def generate_states(
        self,
        optimizer: optax.GradientTransformation,
        rngs: jax.random.PRNGKey,
        existing_state: dict = None,
        existing_best_state: dict = None,
        model: nn.Module = None,
        param_transforms: Callable = None
    ) -> Tuple[TrainState, TrainState]:
        """Create the current and best ``TrainState`` objects.

        Args:
            optimizer: Optax transformation used as the state's ``tx``.
            rngs: PRNG key; split once, one half initialises the model and
                the other is stored on the state.
            existing_state: Optional mapping with ``'params'`` and
                ``'ema_params'``. When ``None`` the model is initialised from
                ``self.get_input_ones()``.
            existing_best_state: Optional mapping with ``'params'`` and
                ``'ema_params'`` used to build the best state.
            model: Module whose ``init`` / ``apply`` are used.
            param_transforms: Optional callable applied to the parameters
                before the state is created.

        Returns:
            ``(state, best_state)``; ``best_state`` is ``state`` itself when
            ``existing_best_state`` is ``None``.
        """
        print("Generating states for AutoEncoderTrainer")
        rngs, subkey = jax.random.split(rngs)

        if existing_state is None:
            input_vars = self.get_input_ones()
            params = model.init(subkey, **input_vars)
            new_state = {"params": params, "ema_params": params}
        else:
            new_state = existing_state

        if param_transforms is not None:
            # Transform the parameters that actually end up in the state.
            # Previously the transform ran on a local that was unbound when
            # ``existing_state`` was given, and its result was discarded.
            # A new dict is built so the caller's mapping is not mutated.
            # NOTE(review): ``ema_params`` is left untransformed - confirm.
            new_state = {
                **new_state,
                "params": param_transforms(new_state["params"]),
            }

        state = TrainState.create(
            apply_fn=model.apply,
            params=new_state['params'],
            ema_params=new_state['ema_params'],
            tx=optimizer,
            rngs=rngs,
            metrics=Metrics.empty()
        )

        if existing_best_state is not None:
            best_state = state.replace(
                params=existing_best_state['params'],
                ema_params=existing_best_state['ema_params'])
        else:
            best_state = state

        return state, best_state

    def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
        """Build the jitted (and, if distributed, sharded) training step.

        Args:
            batch_size: Per-step batch size; sizes the broadcast null labels
                and determines how many samples are made unconditional.
            null_labels_seq: 2-D array (unpacked as ``(nS, nC)``) holding the
                null conditioning, broadcast to ``(batch_size, nS, nC)``.
            text_embedder: Callable invoked with ``input_ids`` and
                ``attention_mask``; its ``last_hidden_state`` is the label
                sequence.

        Returns:
            ``train_step(train_state, rng_state, batch, local_device_index)``
            returning ``(train_state, loss, rng_state)``.
        """
        noise_schedule: NoiseScheduler = self.noise_schedule
        model = self.model
        model_output_transform = self.model_output_transform
        loss_fn = self.loss_fn
        unconditional_prob = self.unconditional_prob

        # Number of samples per batch whose conditioning is replaced by the
        # null labels (a fixed count, not a per-sample random draw).
        num_unconditional = int(batch_size * unconditional_prob)

        nS, nC = null_labels_seq.shape
        null_labels_seq = jnp.broadcast_to(
            null_labels_seq, (batch_size, nS, nC))

        distributed_training = self.distributed_training

        autoencoder = self.autoencoder

        def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, local_device_index):
            """Train for a single step."""
            rng_state, subkey = rng_state.get_random_key()
            # Fold the device index into the key so each shard draws
            # different noise; ``reshape()`` turns the index into a scalar.
            subkey = jax.random.fold_in(subkey, local_device_index.reshape())
            local_rng_state = RandomMarkovState(subkey)

            images = batch['image']

            if autoencoder is not None:
                # Convert the images to latent space
                local_rng_state, rngs = local_rng_state.get_random_key()
                images = autoencoder.encode(images, rngs)
            else:
                # Normalize image: maps 0..255 onto -1..1
                images = (images - 127.5) / 127.5

            output = text_embedder(
                input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
            label_seq = output.last_hidden_state

            # The first ``num_unconditional`` samples get the null labels;
            # the rest keep their own conditioning.
            label_seq = jnp.concat(
                [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)

            noise_level, local_rng_state = noise_schedule.generate_timesteps(
                images.shape[0], local_rng_state)

            local_rng_state, rngs = local_rng_state.get_random_key()
            noise: jax.Array = jax.random.normal(rngs, shape=images.shape)

            rates = noise_schedule.get_rates(noise_level)
            noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
                images, noise, rates)

            def model_loss(params):
                preds = model.apply(
                    params,
                    *noise_schedule.transform_inputs(noisy_images*c_in, noise_level),
                    label_seq)
                preds = model_output_transform.pred_transform(
                    noisy_images, preds, rates)
                nloss = loss_fn(preds, expected_output)
                nloss *= noise_schedule.get_weights(noise_level)
                return jnp.mean(nloss)

            loss, grads = jax.value_and_grad(model_loss)(train_state.params)
            if distributed_training:
                # Average gradients and loss over the 'data' mesh axis.
                grads = jax.lax.pmean(grads, "data")
                loss = jax.lax.pmean(loss, "data")
            train_state = train_state.apply_gradients(grads=grads)
            train_state = train_state.apply_ema(self.ema_decay)
            return train_state, loss, rng_state

        if distributed_training:
            # State and RNG are replicated; the batch and device index are
            # split along the 'data' axis.
            train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
                                   out_specs=(P(), P(), P()))
            train_step = jax.jit(train_step)

        return train_step

    def _define_compute_metrics(self):
        """Return a jitted function that merges an MSE loss into the metrics."""
        @jax.jit
        def compute_metrics(state: TrainState, expected, pred):
            loss = jnp.mean(jnp.square(pred - expected))
            metric_updates = state.metrics.single_from_model_output(loss=loss)
            metrics = state.metrics.merge(metric_updates)
            state = state.replace(metrics=metrics)
            return state
        return compute_metrics

    def fit(self, data, steps_per_epoch, epochs):
        """Run training via ``SimpleTrainer.fit``.

        Args:
            data: Mapping that must contain ``'null_labels_full'``,
                ``'local_batch_size'`` and ``'model'`` (the text embedder).
            steps_per_epoch: Forwarded positionally to the base ``fit``.
            epochs: Forwarded positionally to the base ``fit``.
        """
        null_labels_full = data['null_labels_full']
        local_batch_size = data['local_batch_size']
        text_embedder = data['model']
        # NOTE(review): the dict is the 4th positional argument of
        # ``SimpleTrainer.fit`` - presumably its train-step arguments; verify
        # against the base class signature.
        super().fit(data, steps_per_epoch, epochs, {
            "batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
|
177
|
+
|
178
|
+
def boolean_string(s):
    """Convert a flag value to a bool.

    Args:
        s: Either a ``bool`` (returned unchanged) or a value compared
            against the exact string ``'True'``.

    Returns:
        ``s`` itself when it is a ``bool``; otherwise ``True`` only if
        ``s == 'True'`` (case-sensitive), else ``False``.
    """
    if isinstance(s, bool):
        return s
    return s == 'True'
|
182
|
+
|
@@ -0,0 +1,326 @@
|
|
1
|
+
import flax
|
2
|
+
from flax import linen as nn
|
3
|
+
import jax
|
4
|
+
from typing import Callable
|
5
|
+
from dataclasses import field
|
6
|
+
import jax.numpy as jnp
|
7
|
+
import traceback
|
8
|
+
import optax
|
9
|
+
import functools
|
10
|
+
from jax.sharding import Mesh, PartitionSpec as P
|
11
|
+
from jax.experimental.shard_map import shard_map
|
12
|
+
from typing import Dict, Callable, Sequence, Any, Union, Tuple, Type
|
13
|
+
|
14
|
+
from ..schedulers import NoiseScheduler
|
15
|
+
from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
|
16
|
+
from ..samplers.common import DiffusionSampler
|
17
|
+
|
18
|
+
from flaxdiff.utils import RandomMarkovState
|
19
|
+
|
20
|
+
from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
|
21
|
+
|
22
|
+
from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
|
23
|
+
from flax.training import dynamic_scale as dynamic_scale_lib
|
24
|
+
from flaxdiff.utils import TextEncoder, ConditioningEncoder
|
25
|
+
|
26
|
+
class TrainState(SimpleTrainState):
    """Train state extended with a PRNG key and EMA parameters."""
    # PRNG key carried alongside the state.
    rngs: jax.random.PRNGKey
    # Exponential-moving-average counterpart of ``params``.
    ema_params: dict

    def apply_ema(self, decay: float = 0.999):
        """Return a copy of the state with ``ema_params`` moved one EMA step.

        Every leaf becomes ``decay * ema + (1 - decay) * param``, pairing
        leaves of ``self.ema_params`` with those of ``self.params``.
        """
        def _blend(ema_leaf, param_leaf):
            return decay * ema_leaf + (1 - decay) * param_leaf

        updated_ema = jax.tree_util.tree_map(
            _blend, self.ema_params, self.params)
        return self.replace(ema_params=updated_ema)
|
37
|
+
|
38
|
+
from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
|
39
|
+
|
40
|
+
class DiffusionTrainer(SimpleTrainer):
    """Trainer for conditional diffusion models, built on ``SimpleTrainer``.

    The train step adds noise to (optionally autoencoded) images according to
    ``noise_schedule``, trains the model to predict the target defined by
    ``model_output_transform``, and maintains an EMA copy of the parameters.
    """
    noise_schedule: NoiseScheduler
    model_output_transform: DiffusionPredictionTransform
    ema_decay: float = 0.999

    def __init__(self,
                 model: nn.Module,
                 input_shapes: Dict[str, Tuple[int]],
                 optimizer: optax.GradientTransformation,
                 noise_schedule: NoiseScheduler,
                 rngs: jax.random.PRNGKey,
                 unconditional_prob: float = 0.12,
                 name: str = "Diffusion",
                 model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
                 autoencoder: AutoEncoder = None,
                 encoder: ConditioningEncoder = None,
                 **kwargs
                 ):
        """Initialise the base trainer and store the diffusion configuration.

        Args:
            model: Flax module to train.
            input_shapes: Mapping of input name to shape, forwarded to the
                base trainer. ``_define_vaidation_step`` reads key ``'x'``.
            optimizer: Optax gradient transformation.
            noise_schedule: Scheduler providing timesteps, rates, input
                transforms and loss weights.
            rngs: PRNG key forwarded to the base trainer.
            unconditional_prob: Fraction of each batch trained with the null
                conditioning.
            name: Trainer name forwarded to the base trainer.
            model_output_transform: Defines the forward diffusion and the
                prediction transform. The default instance is created once,
                when the function is defined, and shared by all trainers
                that do not pass their own.
            autoencoder: Optional autoencoder; when set, images are encoded
                to latents before noising.
            encoder: Conditioning encoder. It is called when the train and
                validation steps are built, so it must not be ``None`` by
                then.
            **kwargs: Forwarded unchanged to ``SimpleTrainer.__init__``.
        """
        super().__init__(
            model=model,
            input_shapes=input_shapes,
            optimizer=optimizer,
            rngs=rngs,
            name=name,
            **kwargs
        )
        self.noise_schedule = noise_schedule
        self.model_output_transform = model_output_transform
        self.unconditional_prob = unconditional_prob

        self.autoencoder = autoencoder
        self.encoder = encoder

    def generate_states(
        self,
        optimizer: optax.GradientTransformation,
        rngs: jax.random.PRNGKey,
        existing_state: dict = None,
        existing_best_state: dict = None,
        model: nn.Module = None,
        param_transforms: Callable = None,
        use_dynamic_scale: bool = False
    ) -> Tuple[TrainState, TrainState]:
        """Create the current and best ``TrainState`` objects.

        Args:
            optimizer: Optax transformation used as the state's ``tx``.
            rngs: PRNG key; split once, one half initialises the model and
                the other is stored on the state.
            existing_state: Optional mapping with ``'params'`` and
                ``'ema_params'``. When ``None`` the model is initialised from
                ``self.get_input_ones()``.
            existing_best_state: Optional mapping with ``'params'`` and
                ``'ema_params'`` used to build the best state.
            model: Module whose ``init`` / ``apply`` are used.
            param_transforms: Optional callable applied to the parameters
                before the state is created.
            use_dynamic_scale: When true, attach a Flax ``DynamicScale`` to
                the state; otherwise ``dynamic_scale`` is ``None``.

        Returns:
            ``(state, best_state)``; ``best_state`` is ``state`` itself when
            ``existing_best_state`` is ``None``.
        """
        print("Generating states for DiffusionTrainer")
        rngs, subkey = jax.random.split(rngs)

        if existing_state is None:
            input_vars = self.get_input_ones()
            params = model.init(subkey, **input_vars)
            new_state = {"params": params, "ema_params": params}
        else:
            new_state = existing_state

        if param_transforms is not None:
            # Transform the parameters that actually end up in the state.
            # Previously the transform ran on a local that was unbound when
            # ``existing_state`` was given, and its result was discarded.
            # A new dict is built so the caller's mapping is not mutated.
            # NOTE(review): ``ema_params`` is left untransformed - confirm.
            new_state = {
                **new_state,
                "params": param_transforms(new_state["params"]),
            }

        state = TrainState.create(
            apply_fn=model.apply,
            params=new_state['params'],
            ema_params=new_state['ema_params'],
            tx=optimizer,
            rngs=rngs,
            metrics=Metrics.empty(),
            dynamic_scale=dynamic_scale_lib.DynamicScale() if use_dynamic_scale else None
        )

        if existing_best_state is not None:
            best_state = state.replace(
                params=existing_best_state['params'],
                ema_params=existing_best_state['ema_params'])
        else:
            best_state = state

        return state, best_state

    def _define_train_step(self, batch_size):
        """Build the jitted (and, if distributed, sharded) training step.

        Args:
            batch_size: Per-step batch size; sizes the broadcast null labels
                and determines how many samples are made unconditional.

        Returns:
            ``train_step(train_state, rng_state, batch, local_device_index)``
            returning ``(train_state, loss, rng_state)``.
        """
        noise_schedule: NoiseScheduler = self.noise_schedule
        model = self.model
        model_output_transform = self.model_output_transform
        loss_fn = self.loss_fn
        unconditional_prob = self.unconditional_prob

        # Number of samples per batch whose conditioning is replaced by the
        # null labels (a fixed count, not a per-sample random draw).
        num_unconditional = int(batch_size * unconditional_prob)

        # Encode the empty prompt once to obtain the null conditioning.
        null_labels_full = self.encoder([""])
        null_labels_seq = jnp.array(null_labels_full[0], dtype=jnp.float16)

        conditioning_encoder = self.encoder

        nS, nC = null_labels_seq.shape
        null_labels_seq = jnp.broadcast_to(
            null_labels_seq, (batch_size, nS, nC))

        distributed_training = self.distributed_training

        autoencoder = self.autoencoder

        def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, local_device_index):
            """Train for a single step."""
            rng_state, subkey = rng_state.get_random_key()
            # Fold the device index into the key so each shard draws
            # different noise; ``reshape()`` turns the index into a scalar.
            subkey = jax.random.fold_in(subkey, local_device_index.reshape())
            local_rng_state = RandomMarkovState(subkey)

            images = batch['image']

            images = jnp.array(images, dtype=jnp.float32)
            # Normalize image: maps 0..255 onto -1..1
            images = (images - 127.5) / 127.5

            if autoencoder is not None:
                # Convert the images to latent space
                local_rng_state, rngs = local_rng_state.get_random_key()
                images = autoencoder.encode(images, rngs)

            label_seq = conditioning_encoder.encode_from_tokens(batch)

            # The first ``num_unconditional`` samples get the null labels;
            # the rest keep their own conditioning.
            label_seq = jnp.concat(
                [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)

            noise_level, local_rng_state = noise_schedule.generate_timesteps(
                images.shape[0], local_rng_state)

            local_rng_state, rngs = local_rng_state.get_random_key()
            noise: jax.Array = jax.random.normal(rngs, shape=images.shape)

            rates = noise_schedule.get_rates(noise_level)
            noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
                images, noise, rates)

            def model_loss(params):
                preds = model.apply(
                    params,
                    *noise_schedule.transform_inputs(noisy_images*c_in, noise_level),
                    label_seq)
                preds = model_output_transform.pred_transform(
                    noisy_images, preds, rates)
                nloss = loss_fn(preds, expected_output)
                nloss *= noise_schedule.get_weights(noise_level)
                return jnp.mean(nloss)

            if train_state.dynamic_scale is not None:
                # Dynamic scale averages gradients across replicas itself,
                # but only when given an axis name. The 'data' axis exists
                # only inside shard_map, so pass it only when distributed;
                # otherwise the pmean would fail on an unbound axis.
                grad_fn = train_state.dynamic_scale.value_and_grad(
                    model_loss,
                    axis_name="data" if distributed_training else None,
                )
                dynamic_scale, is_fin, loss, grads = grad_fn(train_state.params)
                train_state = train_state.replace(dynamic_scale=dynamic_scale)
            else:
                grad_fn = jax.value_and_grad(model_loss)
                loss, grads = grad_fn(train_state.params)
                if distributed_training:
                    grads = jax.lax.pmean(grads, "data")

            new_state = train_state.apply_gradients(grads=grads)

            if train_state.dynamic_scale is not None:
                # If is_fin is False the gradients contain Inf/NaNs, so the
                # optimizer state and params are restored (= skip this step).
                select_fn = functools.partial(jnp.where, is_fin)
                new_state = new_state.replace(
                    opt_state=jax.tree_util.tree_map(
                        select_fn, new_state.opt_state, train_state.opt_state
                    ),
                    params=jax.tree_util.tree_map(
                        select_fn, new_state.params, train_state.params
                    ),
                )

            train_state = new_state.apply_ema(self.ema_decay)

            if distributed_training:
                loss = jax.lax.pmean(loss, "data")
            return train_state, loss, rng_state

        if distributed_training:
            # State and RNG are replicated; the batch and device index are
            # split along the 'data' axis.
            train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
                                   out_specs=(P(), P(), P()))
            train_step = jax.jit(train_step)

        return train_step

    def _define_vaidation_step(self, sampler_class: Type[DiffusionSampler]):
        """Build the two callables used by ``validation_loop``.

        Args:
            sampler_class: ``DiffusionSampler`` subclass to instantiate.

        Returns:
            ``(generate_sampler, generate_samples)``: the first builds a
            sampler from a train state's EMA parameters, the second encodes a
            batch and generates images with a given sampler.
        """
        model = self.model
        encoder = self.encoder
        autoencoder = self.autoencoder

        # Null conditioning from the empty prompt, cast to float16.
        null_labels_full = encoder([""])
        null_labels_full = null_labels_full.astype(jnp.float16)

        def generate_sampler(state: TrainState):
            sampler = sampler_class(
                model=model,
                params=state.ema_params,
                noise_schedule=self.noise_schedule,
                model_output_transform=self.model_output_transform,
                image_size=self.input_shapes['x'][0],
                null_labels_seq=null_labels_full,
                autoencoder=autoencoder,
            )
            return sampler

        def generate_samples(
            batch,
            sampler: DiffusionSampler,
            diffusion_steps: int,
        ):
            labels_seq = encoder.encode_from_tokens(batch)
            labels_seq = jnp.array(labels_seq, dtype=jnp.float16)
            samples = sampler.generate_images(
                num_images=len(labels_seq),
                diffusion_steps=diffusion_steps,
                start_step=1000,
                end_step=0,
                priors=None,
                model_conditioning_inputs=(labels_seq,),
            )
            return samples

        return generate_sampler, generate_samples

    def validation_loop(
        self,
        val_state: SimpleTrainState,
        val_step_fn: Callable,
        val_ds,
        val_steps_per_epoch,
        current_step,
        diffusion_steps=200,
    ):
        """Generate samples from one validation batch and log them to wandb.

        Args:
            val_state: State whose EMA parameters drive the sampler.
            val_step_fn: The ``(generate_sampler, generate_samples)`` pair
                returned by ``_define_vaidation_step``.
            val_ds: Callable returning an iterable of batches, or a falsy
                value. Only the first batch is used.
            val_steps_per_epoch: Accepted for interface compatibility; not
                used here.
            current_step: Step used for the wandb log and image captions.
            diffusion_steps: Number of sampling steps.

        Failures during generation or logging are caught, printed with a
        traceback, and do not propagate (deliberate best effort).
        """
        generate_sampler, generate_samples = val_step_fn

        sampler = generate_sampler(val_state)

        val_ds = iter(val_ds()) if val_ds else None
        # Evaluation step
        try:
            samples = generate_samples(
                next(val_ds),
                sampler,
                diffusion_steps,
            )

            # Put each sample on wandb
            if self.wandb:
                import numpy as np
                from wandb import Image as wandbImage
                for i in range(samples.shape[0]):
                    # Convert the sample to numpy
                    sample = np.array(samples[i])
                    # Denormalize the image: maps -1..1 back onto 0..255
                    sample = (sample + 1) * 127.5
                    sample = np.clip(sample, 0, 255).astype(np.uint8)
                    # Log each image under its own key
                    self.wandb.log({
                        f"sample_{i}": wandbImage(sample, caption=f"Sample {i} at step {current_step}")
                    }, step=current_step)
        except Exception as e:
            print("Error logging images to wandb", e)
            traceback.print_exc()

    def fit(self, data, training_steps_per_epoch, epochs, val_steps_per_epoch=8, sampler_class=None):
        """Run training via ``SimpleTrainer.fit``.

        Args:
            data: Mapping that must contain ``'local_batch_size'``; passed
                through to the base ``fit``.
            training_steps_per_epoch: Forwarded as ``train_steps_per_epoch``.
            epochs: Number of epochs.
            val_steps_per_epoch: Forwarded to the base ``fit``.
            sampler_class: Sampler class handed to the validation step
                builder via ``validation_step_args``.
        """
        local_batch_size = data['local_batch_size']
        validation_step_args = {
            "sampler_class": sampler_class,
        }
        super().fit(
            data,
            train_steps_per_epoch=training_steps_per_epoch,
            epochs=epochs,
            train_step_args={"batch_size": local_batch_size},
            val_steps_per_epoch=val_steps_per_epoch,
            validation_step_args=validation_step_args,
        )
|