flaxdiff 0.1.38__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/__init__.py +5 -1
- flaxdiff/data/benchmark_decord.py +443 -0
- flaxdiff/data/dataloaders.py +608 -0
- flaxdiff/data/dataset_map.py +61 -6
- flaxdiff/data/online_loader.py +779 -150
- flaxdiff/data/sources/audio_utils.py +142 -0
- flaxdiff/data/sources/av_example.py +125 -0
- flaxdiff/data/sources/av_utils.py +590 -0
- flaxdiff/data/sources/base.py +129 -0
- flaxdiff/data/sources/images.py +309 -0
- flaxdiff/data/sources/utils.py +158 -0
- flaxdiff/data/sources/videos.py +250 -0
- flaxdiff/data/sources/voxceleb2.py +412 -0
- flaxdiff/inference/__init__.py +0 -0
- flaxdiff/inference/pipeline.py +260 -0
- flaxdiff/inference/utils.py +320 -0
- flaxdiff/inputs/__init__.py +173 -0
- flaxdiff/inputs/encoders.py +98 -0
- flaxdiff/models/__init__.py +2 -1
- flaxdiff/models/attention.py +22 -16
- flaxdiff/models/autoencoder/autoencoder.py +141 -9
- flaxdiff/models/autoencoder/diffusers.py +88 -25
- flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
- flaxdiff/models/common.py +8 -18
- flaxdiff/models/simple_unet.py +6 -17
- flaxdiff/models/simple_vit.py +9 -13
- flaxdiff/models/unet_3d.py +446 -0
- flaxdiff/models/unet_3d_blocks.py +505 -0
- flaxdiff/samplers/common.py +358 -96
- flaxdiff/samplers/ddim.py +44 -5
- flaxdiff/schedulers/karras.py +20 -12
- flaxdiff/trainer/__init__.py +2 -1
- flaxdiff/trainer/autoencoder_trainer.py +1 -2
- flaxdiff/trainer/diffusion_trainer.py +35 -29
- flaxdiff/trainer/general_diffusion_trainer.py +583 -0
- flaxdiff/trainer/simple_trainer.py +51 -16
- flaxdiff/utils.py +128 -57
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
- flaxdiff-0.2.0.dist-info/RECORD +64 -0
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
- flaxdiff/data/datasets.py +0 -169
- flaxdiff/data/sources/gcs.py +0 -81
- flaxdiff/data/sources/tfds.py +0 -79
- flaxdiff/trainer/video_diffusion_trainer.py +0 -62
- flaxdiff-0.1.38.dist-info/RECORD +0 -50
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
@@ -22,7 +22,7 @@ from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
|
|
22
22
|
|
23
23
|
from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
|
24
24
|
from flax.training import dynamic_scale as dynamic_scale_lib
|
25
|
-
from flaxdiff.
|
25
|
+
from flaxdiff.inputs import TextEncoder, ConditioningEncoder
|
26
26
|
|
27
27
|
class TrainState(SimpleTrainState):
|
28
28
|
rngs: jax.random.PRNGKey
|
@@ -42,6 +42,7 @@ class DiffusionTrainer(SimpleTrainer):
|
|
42
42
|
noise_schedule: NoiseScheduler
|
43
43
|
model_output_transform: DiffusionPredictionTransform
|
44
44
|
ema_decay: float = 0.999
|
45
|
+
native_resolution: int = None
|
45
46
|
|
46
47
|
def __init__(self,
|
47
48
|
model: nn.Module,
|
@@ -54,6 +55,7 @@ class DiffusionTrainer(SimpleTrainer):
|
|
54
55
|
model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
|
55
56
|
autoencoder: AutoEncoder = None,
|
56
57
|
encoder: ConditioningEncoder = None,
|
58
|
+
native_resolution: int = None,
|
57
59
|
**kwargs
|
58
60
|
):
|
59
61
|
super().__init__(
|
@@ -68,6 +70,20 @@ class DiffusionTrainer(SimpleTrainer):
|
|
68
70
|
self.model_output_transform = model_output_transform
|
69
71
|
self.unconditional_prob = unconditional_prob
|
70
72
|
|
73
|
+
if native_resolution is None:
|
74
|
+
if 'image' in input_shapes:
|
75
|
+
native_resolution = input_shapes['image'][1]
|
76
|
+
elif 'x' in input_shapes:
|
77
|
+
native_resolution = input_shapes['x'][1]
|
78
|
+
elif 'sample' in input_shapes:
|
79
|
+
native_resolution = input_shapes['sample'][1]
|
80
|
+
else:
|
81
|
+
raise ValueError("No image input shape found in input shapes")
|
82
|
+
if autoencoder is not None:
|
83
|
+
native_resolution = native_resolution * 8
|
84
|
+
|
85
|
+
self.native_resolution = native_resolution
|
86
|
+
|
71
87
|
self.autoencoder = autoencoder
|
72
88
|
self.encoder = encoder
|
73
89
|
|
@@ -118,9 +134,6 @@ class DiffusionTrainer(SimpleTrainer):
|
|
118
134
|
model_output_transform = self.model_output_transform
|
119
135
|
loss_fn = self.loss_fn
|
120
136
|
unconditional_prob = self.unconditional_prob
|
121
|
-
|
122
|
-
# Determine the number of unconditional samples
|
123
|
-
num_unconditional = int(batch_size * unconditional_prob)
|
124
137
|
|
125
138
|
null_labels_full = self.encoder([""])
|
126
139
|
null_labels_seq = jnp.array(null_labels_full[0], dtype=jnp.float16)
|
@@ -159,12 +172,19 @@ class DiffusionTrainer(SimpleTrainer):
|
|
159
172
|
local_rng_state, rngs = local_rng_state.get_random_key()
|
160
173
|
images = autoencoder.encode(images, rngs)
|
161
174
|
|
162
|
-
label_seq = conditioning_encoder.encode_from_tokens(batch)
|
175
|
+
label_seq = conditioning_encoder.encode_from_tokens(batch['text'])
|
163
176
|
|
164
177
|
# Generate random probabilities to decide how much of this batch will be unconditional
|
178
|
+
local_rng_state, uncond_key = local_rng_state.get_random_key()
|
179
|
+
# Efficient way to determine unconditional samples for JIT compatibility
|
180
|
+
uncond_mask = jax.random.bernoulli(
|
181
|
+
uncond_key,
|
182
|
+
shape=(local_batch_size,),
|
183
|
+
p=unconditional_prob
|
184
|
+
)
|
185
|
+
num_unconditional = jnp.sum(uncond_mask).astype(jnp.int32)
|
165
186
|
|
166
|
-
label_seq = jnp.
|
167
|
-
[null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)
|
187
|
+
label_seq = jnp.concatenate([null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)
|
168
188
|
|
169
189
|
noise_level, local_rng_state = noise_schedule.generate_timesteps(local_batch_size, local_rng_state)
|
170
190
|
|
@@ -200,21 +220,6 @@ class DiffusionTrainer(SimpleTrainer):
|
|
200
220
|
loss, grads = grad_fn(train_state.params)
|
201
221
|
if distributed_training:
|
202
222
|
grads = jax.lax.pmean(grads, "data")
|
203
|
-
|
204
|
-
# # check gradients for NaN/Inf
|
205
|
-
# has_nan_or_inf = jax.tree_util.tree_reduce(
|
206
|
-
# lambda acc, x: jnp.logical_or(acc, jnp.logical_or(jnp.isnan(x).any(), jnp.isinf(x).any())),
|
207
|
-
# grads,
|
208
|
-
# initializer=False
|
209
|
-
# )
|
210
|
-
|
211
|
-
# # Only apply gradients if they're valid
|
212
|
-
# new_state = jax.lax.cond(
|
213
|
-
# has_nan_or_inf,
|
214
|
-
# lambda _: train_state, # Skip gradient update
|
215
|
-
# lambda _: train_state.apply_gradients(grads=grads),
|
216
|
-
# operand=None
|
217
|
-
# )
|
218
223
|
|
219
224
|
new_state = train_state.apply_gradients(grads=grads)
|
220
225
|
|
@@ -231,11 +236,11 @@ class DiffusionTrainer(SimpleTrainer):
|
|
231
236
|
),
|
232
237
|
)
|
233
238
|
|
234
|
-
|
239
|
+
new_state = new_state.apply_ema(self.ema_decay)
|
235
240
|
|
236
241
|
if distributed_training:
|
237
242
|
loss = jax.lax.pmean(loss, "data")
|
238
|
-
return
|
243
|
+
return new_state, loss, rng_state
|
239
244
|
|
240
245
|
if distributed_training:
|
241
246
|
train_step = shard_map(
|
@@ -251,7 +256,7 @@ class DiffusionTrainer(SimpleTrainer):
|
|
251
256
|
|
252
257
|
return train_step
|
253
258
|
|
254
|
-
def
|
259
|
+
def _define_validation_step(self, sampler_class: Type[DiffusionSampler]=DDIMSampler, sampling_noise_schedule: NoiseScheduler=None):
|
255
260
|
model = self.model
|
256
261
|
encoder = self.encoder
|
257
262
|
autoencoder = self.autoencoder
|
@@ -260,7 +265,9 @@ class DiffusionTrainer(SimpleTrainer):
|
|
260
265
|
null_labels_full = null_labels_full.astype(jnp.float16)
|
261
266
|
# null_labels_seq = jnp.array(null_labels_full[0], dtype=jnp.float16)
|
262
267
|
|
263
|
-
if
|
268
|
+
if self.native_resolution is not None:
|
269
|
+
image_size = self.native_resolution
|
270
|
+
elif 'image' in self.input_shapes:
|
264
271
|
image_size = self.input_shapes['image'][1]
|
265
272
|
elif 'x' in self.input_shapes:
|
266
273
|
image_size = self.input_shapes['x'][1]
|
@@ -271,10 +278,8 @@ class DiffusionTrainer(SimpleTrainer):
|
|
271
278
|
|
272
279
|
sampler = sampler_class(
|
273
280
|
model=model,
|
274
|
-
params=None,
|
275
281
|
noise_schedule=self.noise_schedule if sampling_noise_schedule is None else sampling_noise_schedule,
|
276
282
|
model_output_transform=self.model_output_transform,
|
277
|
-
image_size=image_size,
|
278
283
|
null_labels_seq=null_labels_full,
|
279
284
|
autoencoder=autoencoder,
|
280
285
|
guidance_scale=3.0,
|
@@ -290,7 +295,8 @@ class DiffusionTrainer(SimpleTrainer):
|
|
290
295
|
labels_seq = jnp.array(labels_seq, dtype=jnp.float16)
|
291
296
|
samples = sampler.generate_images(
|
292
297
|
params=val_state.ema_params,
|
293
|
-
|
298
|
+
resolution=image_size,
|
299
|
+
num_samples=len(labels_seq),
|
294
300
|
diffusion_steps=diffusion_steps,
|
295
301
|
start_step=1000,
|
296
302
|
end_step=0,
|