flaxdiff 0.1.36.1__py3-none-any.whl → 0.1.36.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. flaxdiff/data/__init__.py +1 -0
  2. flaxdiff/data/dataset_map.py +71 -0
  3. flaxdiff/data/datasets.py +169 -0
  4. flaxdiff/data/online_loader.py +363 -0
  5. flaxdiff/data/sources/gcs.py +81 -0
  6. flaxdiff/data/sources/tfds.py +67 -0
  7. flaxdiff/metrics/inception.py +658 -0
  8. flaxdiff/metrics/utils.py +49 -0
  9. flaxdiff/models/__init__.py +1 -0
  10. flaxdiff/models/attention.py +368 -0
  11. flaxdiff/models/autoencoder/__init__.py +2 -0
  12. flaxdiff/models/autoencoder/autoencoder.py +19 -0
  13. flaxdiff/models/autoencoder/diffusers.py +91 -0
  14. flaxdiff/models/autoencoder/simple_autoenc.py +26 -0
  15. flaxdiff/models/common.py +346 -0
  16. flaxdiff/models/favor_fastattn.py +723 -0
  17. flaxdiff/models/simple_unet.py +233 -0
  18. flaxdiff/models/simple_vit.py +180 -0
  19. flaxdiff/predictors/__init__.py +96 -0
  20. flaxdiff/samplers/__init__.py +7 -0
  21. flaxdiff/samplers/common.py +165 -0
  22. flaxdiff/samplers/ddim.py +10 -0
  23. flaxdiff/samplers/ddpm.py +37 -0
  24. flaxdiff/samplers/euler.py +56 -0
  25. flaxdiff/samplers/heun_sampler.py +27 -0
  26. flaxdiff/samplers/multistep_dpm.py +59 -0
  27. flaxdiff/samplers/rk4_sampler.py +34 -0
  28. flaxdiff/schedulers/__init__.py +6 -0
  29. flaxdiff/schedulers/common.py +98 -0
  30. flaxdiff/schedulers/continuous.py +12 -0
  31. flaxdiff/schedulers/cosine.py +40 -0
  32. flaxdiff/schedulers/discrete.py +74 -0
  33. flaxdiff/schedulers/exp.py +13 -0
  34. flaxdiff/schedulers/karras.py +69 -0
  35. flaxdiff/schedulers/linear.py +14 -0
  36. flaxdiff/schedulers/sqrt.py +10 -0
  37. flaxdiff/trainer/__init__.py +2 -0
  38. flaxdiff/trainer/autoencoder_trainer.py +182 -0
  39. flaxdiff/trainer/diffusion_trainer.py +326 -0
  40. flaxdiff/trainer/simple_trainer.py +540 -0
  41. flaxdiff/trainer/video_diffusion_trainer.py +62 -0
  42. {flaxdiff-0.1.36.1.dist-info → flaxdiff-0.1.36.3.dist-info}/METADATA +1 -1
  43. flaxdiff-0.1.36.3.dist-info/RECORD +47 -0
  44. flaxdiff-0.1.36.1.dist-info/RECORD +0 -6
  45. {flaxdiff-0.1.36.1.dist-info → flaxdiff-0.1.36.3.dist-info}/WHEEL +0 -0
  46. {flaxdiff-0.1.36.1.dist-info → flaxdiff-0.1.36.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,182 @@
1
+ from flax import linen as nn
2
+ import jax
3
+ from typing import Callable
4
+ from dataclasses import field
5
+ import jax.numpy as jnp
6
+ import optax
7
+ from jax.sharding import Mesh, PartitionSpec as P
8
+ from jax.experimental.shard_map import shard_map
9
+ from typing import Dict, Callable, Sequence, Any, Union, Tuple
10
+
11
+ from ..schedulers import NoiseScheduler
12
+ from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
13
+
14
+ from flaxdiff.utils import RandomMarkovState
15
+
16
+ from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
17
+ from .diffusion_trainer import TrainState
18
+ from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
19
+
20
class AutoEncoderTrainer(SimpleTrainer):
    """Trainer for an image autoencoder, built on ``SimpleTrainer``.

    NOTE(review): ``_define_train_step`` is a copy of the diffusion training
    step. It reads ``noise_schedule``, ``model_output_transform``,
    ``unconditional_prob``, ``autoencoder`` and ``ema_decay`` from ``self``,
    and this class never assigns any of them. They must come from the base
    class, a subclass or the caller, otherwise an AttributeError is raised.
    TODO confirm the intended design.
    """

    def __init__(self,
                 model: nn.Module,
                 input_shape: Tuple[int, int, int],
                 latent_dim: int,
                 spatial_scale: int,
                 optimizer: optax.GradientTransformation,
                 rngs: jax.random.PRNGKey,
                 name: str = "Autoencoder",
                 **kwargs
                 ):
        """Create the trainer.

        Args:
            model: Flax module to train.
            input_shape: shape of one input, registered with the base trainer
                under the ``"image"`` key (presumably (H, W, C); TODO confirm).
            latent_dim: stored on the instance; not otherwise used in this class.
            spatial_scale: stored on the instance; not otherwise used in this class.
            optimizer: optax gradient transformation.
            rngs: PRNG key forwarded to ``SimpleTrainer``.
            name: run name forwarded to ``SimpleTrainer``.
            **kwargs: forwarded unchanged to ``SimpleTrainer.__init__``.
        """
        super().__init__(
            model=model,
            input_shapes={"image": input_shape},
            optimizer=optimizer,
            rngs=rngs,
            name=name,
            **kwargs
        )
        self.latent_dim = latent_dim
        self.spatial_scale = spatial_scale

    def generate_states(
        self,
        optimizer: optax.GradientTransformation,
        rngs: jax.random.PRNGKey,
        existing_state: dict = None,
        existing_best_state: dict = None,
        model: nn.Module = None,
        param_transforms: Callable = None
    ) -> Tuple[TrainState, TrainState]:
        """Build the current and best ``TrainState``.

        Args:
            optimizer: optax transformation used as the state's ``tx``.
            rngs: PRNG key. One split is used for ``model.init`` and the
                other is stored on the state.
            existing_state: optional dict with ``"params"`` and ``"ema_params"``
                to resume from. When ``None``, parameters are initialised from
                ``self.get_input_ones()``.
            existing_best_state: optional dict with ``"params"`` and
                ``"ema_params"`` for the best state. Defaults to the current state.
            model: module whose ``init``/``apply`` are used.
            param_transforms: optional callable applied to the parameters
                before the state is created.

        Returns:
            ``(state, best_state)``.
        """
        print("Generating states for AutoEncoderTrainer")
        rngs, subkey = jax.random.split(rngs)

        if existing_state is None:
            input_vars = self.get_input_ones()
            params = model.init(subkey, **input_vars)
            if param_transforms is not None:
                params = param_transforms(params)
            # A fresh EMA starts identical to the (transformed) parameters.
            new_state = {"params": params, "ema_params": params}
        else:
            new_state = existing_state
            if param_transforms is not None:
                # Build a new dict so the caller's existing_state is not mutated.
                new_state = {**existing_state,
                             "params": param_transforms(existing_state["params"])}

        state = TrainState.create(
            apply_fn=model.apply,
            params=new_state['params'],
            ema_params=new_state['ema_params'],
            tx=optimizer,
            rngs=rngs,
            metrics=Metrics.empty()
        )

        if existing_best_state is not None:
            best_state = state.replace(
                params=existing_best_state['params'], ema_params=existing_best_state['ema_params'])
        else:
            best_state = state

        return state, best_state

    def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
        """Build the jitted training step, wrapped in ``shard_map`` when distributed.

        Args:
            batch_size: local batch size. The first
                ``int(batch_size * self.unconditional_prob)`` samples of every
                batch are trained with the null conditioning.
            null_labels_seq: 2-D ``(seq_len, channels)`` null-conditioning
                sequence, broadcast over the batch.
            text_embedder: callable taking ``input_ids`` and ``attention_mask``
                and returning an object with ``last_hidden_state``.

        Returns:
            ``train_step(train_state, rng_state, batch, local_device_index)``,
            which returns ``(train_state, loss, rng_state)``.
        """
        # NOTE(review): none of these attributes is set by AutoEncoderTrainer.
        noise_schedule: NoiseScheduler = self.noise_schedule
        model = self.model
        model_output_transform = self.model_output_transform
        loss_fn = self.loss_fn
        unconditional_prob = self.unconditional_prob
        ema_decay = self.ema_decay

        # Number of leading samples per batch that get the null conditioning.
        num_unconditional = int(batch_size * unconditional_prob)

        nS, nC = null_labels_seq.shape
        null_labels_seq = jnp.broadcast_to(
            null_labels_seq, (batch_size, nS, nC))

        distributed_training = self.distributed_training

        autoencoder = self.autoencoder

        def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, local_device_index):
            """Train for a single step."""
            # Derive a per-device key so each shard draws different noise.
            rng_state, subkey = rng_state.get_random_key()
            subkey = jax.random.fold_in(subkey, local_device_index.reshape())
            local_rng_state = RandomMarkovState(subkey)

            images = batch['image']

            if autoencoder is not None:
                # Convert the images to latent space.
                local_rng_state, rngs = local_rng_state.get_random_key()
                images = autoencoder.encode(images, rngs)
            else:
                # Map [0, 255] pixel values to [-1, 1] (assumes uint8-range input).
                images = (images - 127.5) / 127.5

            output = text_embedder(
                input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
            label_seq = output.last_hidden_state

            # Replace the leading conditioning sequences with the null sequence.
            label_seq = jnp.concat(
                [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)

            noise_level, local_rng_state = noise_schedule.generate_timesteps(images.shape[0], local_rng_state)

            local_rng_state, rngs = local_rng_state.get_random_key()
            noise: jax.Array = jax.random.normal(rngs, shape=images.shape)

            rates = noise_schedule.get_rates(noise_level)
            noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
                images, noise, rates)

            def model_loss(params):
                preds = model.apply(params, *noise_schedule.transform_inputs(noisy_images * c_in, noise_level), label_seq)
                preds = model_output_transform.pred_transform(
                    noisy_images, preds, rates)
                nloss = loss_fn(preds, expected_output)
                nloss *= noise_schedule.get_weights(noise_level)
                return jnp.mean(nloss)

            loss, grads = jax.value_and_grad(model_loss)(train_state.params)
            if distributed_training:
                grads = jax.lax.pmean(grads, "data")
                loss = jax.lax.pmean(loss, "data")
            train_state = train_state.apply_gradients(grads=grads)
            train_state = train_state.apply_ema(ema_decay)
            return train_state, loss, rng_state

        if distributed_training:
            # State and RNG are replicated; the batch and device index are sharded.
            train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
                                   out_specs=(P(), P(), P()))
        train_step = jax.jit(train_step)

        return train_step

    def _define_compute_metrics(self):
        """Return a jitted function that merges the MSE of ``pred`` vs ``expected`` into ``state.metrics``."""
        @jax.jit
        def compute_metrics(state: TrainState, expected, pred):
            loss = jnp.mean(jnp.square(pred - expected))
            metric_updates = state.metrics.single_from_model_output(loss=loss)
            metrics = state.metrics.merge(metric_updates)
            state = state.replace(metrics=metrics)
            return state
        return compute_metrics

    def fit(self, data, steps_per_epoch, epochs):
        """Train using ``data['null_labels_full']``, ``data['local_batch_size']`` and ``data['model']``.

        These three entries are forwarded to ``_define_train_step`` as
        ``null_labels_seq``, ``batch_size`` and ``text_embedder``.
        """
        null_labels_full = data['null_labels_full']
        local_batch_size = data['local_batch_size']
        text_embedder = data['model']
        super().fit(data, steps_per_epoch, epochs, {
            "batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
177
+
178
def boolean_string(s):
    """Interpret ``s`` as a boolean flag.

    A ``bool`` is returned unchanged. Any other value is ``True`` only when it
    equals the exact string ``'True'`` (case-sensitive). Everything else,
    including ``'true'``, is ``False``.
    """
    if isinstance(s, bool):
        return s
    return s == 'True'
182
+
@@ -0,0 +1,326 @@
1
+ import flax
2
+ from flax import linen as nn
3
+ import jax
4
+ from typing import Callable
5
+ from dataclasses import field
6
+ import jax.numpy as jnp
7
+ import traceback
8
+ import optax
9
+ import functools
10
+ from jax.sharding import Mesh, PartitionSpec as P
11
+ from jax.experimental.shard_map import shard_map
12
+ from typing import Dict, Callable, Sequence, Any, Union, Tuple, Type
13
+
14
+ from ..schedulers import NoiseScheduler
15
+ from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
16
+ from ..samplers.common import DiffusionSampler
17
+
18
+ from flaxdiff.utils import RandomMarkovState
19
+
20
+ from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
21
+
22
+ from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
23
+ from flax.training import dynamic_scale as dynamic_scale_lib
24
+ from flaxdiff.utils import TextEncoder, ConditioningEncoder
25
+
26
class TrainState(SimpleTrainState):
    """``SimpleTrainState`` extended with a PRNG key and EMA parameters."""

    rngs: jax.random.PRNGKey  # PRNG key stored alongside the state
    ema_params: dict  # exponential moving average of ``params``

    def apply_ema(self, decay: float = 0.999):
        """Return a copy of the state with ``ema_params`` moved towards ``params``.

        Every leaf is updated as ``decay * ema + (1 - decay) * param``.
        The state itself is not modified.
        """
        def _blend(ema_leaf, param_leaf):
            return decay * ema_leaf + (1 - decay) * param_leaf

        blended = jax.tree_util.tree_map(_blend, self.ema_params, self.params)
        return self.replace(ema_params=blended)
37
+
38
+ from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
39
+
40
class DiffusionTrainer(SimpleTrainer):
    """Trainer for conditioned diffusion models, built on ``SimpleTrainer``.

    It adds:
    - a noise schedule;
    - a prediction-target transform;
    - optional latent-space training through an autoencoder;
    - conditioning from a ``ConditioningEncoder``, where a fixed fraction
      (``unconditional_prob``) of every batch uses the encoding of ``""``
      (presumably classifier-free-guidance style dropout).
    """

    noise_schedule: NoiseScheduler
    model_output_transform: DiffusionPredictionTransform
    ema_decay: float = 0.999

    def __init__(self,
                 model: nn.Module,
                 input_shapes: Dict[str, Tuple[int]],
                 optimizer: optax.GradientTransformation,
                 noise_schedule: NoiseScheduler,
                 rngs: jax.random.PRNGKey,
                 unconditional_prob: float = 0.12,
                 name: str = "Diffusion",
                 model_output_transform: DiffusionPredictionTransform = None,
                 autoencoder: AutoEncoder = None,
                 encoder: ConditioningEncoder = None,
                 **kwargs
                 ):
        """Create the trainer.

        Args:
            model: Flax diffusion model.
            input_shapes: per-input shapes forwarded to ``SimpleTrainer``.
            optimizer: optax gradient transformation.
            noise_schedule: schedule providing timesteps, rates and loss weights.
            rngs: PRNG key forwarded to ``SimpleTrainer``.
            unconditional_prob: fraction of each batch trained with the null
                conditioning.
            name: run name forwarded to ``SimpleTrainer``.
            model_output_transform: prediction-target transform. Defaults to a
                new ``EpsilonPredictionTransform()`` per trainer.
            autoencoder: optional autoencoder. When set, training happens on its
                ``encode`` output.
            encoder: conditioning encoder, used for the null labels and for
                ``encode_from_tokens``.
            **kwargs: forwarded unchanged to ``SimpleTrainer.__init__``.
        """
        super().__init__(
            model=model,
            input_shapes=input_shapes,
            optimizer=optimizer,
            rngs=rngs,
            name=name,
            **kwargs
        )
        if model_output_transform is None:
            # Built per instance rather than once at definition time, so no
            # transform object is shared between trainers.
            model_output_transform = EpsilonPredictionTransform()
        self.noise_schedule = noise_schedule
        self.model_output_transform = model_output_transform
        self.unconditional_prob = unconditional_prob

        self.autoencoder = autoencoder
        self.encoder = encoder

    def generate_states(
        self,
        optimizer: optax.GradientTransformation,
        rngs: jax.random.PRNGKey,
        existing_state: dict = None,
        existing_best_state: dict = None,
        model: nn.Module = None,
        param_transforms: Callable = None,
        use_dynamic_scale: bool = False
    ) -> Tuple[TrainState, TrainState]:
        """Build the current and best ``TrainState``.

        Args:
            optimizer: optax transformation used as the state's ``tx``.
            rngs: PRNG key. One split is used for ``model.init`` and the
                other is stored on the state.
            existing_state: optional dict with ``"params"`` and ``"ema_params"``
                to resume from. When ``None``, parameters are initialised from
                ``self.get_input_ones()``.
            existing_best_state: optional dict with ``"params"`` and
                ``"ema_params"`` for the best state. Defaults to the current state.
            model: module whose ``init``/``apply`` are used.
            param_transforms: optional callable applied to the parameters
                before the state is created.
            use_dynamic_scale: attach a ``DynamicScale`` for loss scaling.

        Returns:
            ``(state, best_state)``.
        """
        print("Generating states for DiffusionTrainer")
        rngs, subkey = jax.random.split(rngs)

        if existing_state is None:
            input_vars = self.get_input_ones()
            params = model.init(subkey, **input_vars)
            if param_transforms is not None:
                params = param_transforms(params)
            # A fresh EMA starts identical to the (transformed) parameters.
            new_state = {"params": params, "ema_params": params}
        else:
            new_state = existing_state
            if param_transforms is not None:
                # Build a new dict so the caller's existing_state is not mutated.
                new_state = {**existing_state,
                             "params": param_transforms(existing_state["params"])}

        state = TrainState.create(
            apply_fn=model.apply,
            params=new_state['params'],
            ema_params=new_state['ema_params'],
            tx=optimizer,
            rngs=rngs,
            metrics=Metrics.empty(),
            dynamic_scale=dynamic_scale_lib.DynamicScale() if use_dynamic_scale else None
        )

        if existing_best_state is not None:
            best_state = state.replace(
                params=existing_best_state['params'], ema_params=existing_best_state['ema_params'])
        else:
            best_state = state

        return state, best_state

    def _define_train_step(self, batch_size):
        """Build the jitted training step, wrapped in ``shard_map`` when distributed.

        Args:
            batch_size: local batch size. The first
                ``int(batch_size * unconditional_prob)`` samples of every batch
                get the null (``""``) conditioning.

        Returns:
            ``train_step(train_state, rng_state, batch, local_device_index)``,
            which returns ``(train_state, loss, rng_state)``.
        """
        noise_schedule: NoiseScheduler = self.noise_schedule
        model = self.model
        model_output_transform = self.model_output_transform
        loss_fn = self.loss_fn
        unconditional_prob = self.unconditional_prob
        ema_decay = self.ema_decay

        # Number of leading samples per batch that get the null conditioning.
        num_unconditional = int(batch_size * unconditional_prob)

        null_labels_full = self.encoder([""])
        null_labels_seq = jnp.array(null_labels_full[0], dtype=jnp.float16)

        conditioning_encoder = self.encoder

        nS, nC = null_labels_seq.shape
        null_labels_seq = jnp.broadcast_to(
            null_labels_seq, (batch_size, nS, nC))

        distributed_training = self.distributed_training
        # The "data" axis exists only inside shard_map. Naming it outside a
        # mapped context makes lax.pmean fail, so DynamicScale only gets an
        # axis name when training is distributed.
        grad_axis_name = "data" if distributed_training else None

        autoencoder = self.autoencoder

        def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, local_device_index):
            """Train for a single step."""
            # Derive a per-device key so each shard draws different noise/timesteps.
            rng_state, subkey = rng_state.get_random_key()
            subkey = jax.random.fold_in(subkey, local_device_index.reshape())
            local_rng_state = RandomMarkovState(subkey)

            # Map pixel values from [0, 255] to [-1, 1] (assumes uint8-range images).
            images = jnp.array(batch['image'], dtype=jnp.float32)
            images = (images - 127.5) / 127.5

            if autoencoder is not None:
                # Convert the images to latent space.
                local_rng_state, rngs = local_rng_state.get_random_key()
                images = autoencoder.encode(images, rngs)

            label_seq = conditioning_encoder.encode_from_tokens(batch)

            # Replace the leading conditioning sequences with the null sequence.
            label_seq = jnp.concat(
                [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)

            noise_level, local_rng_state = noise_schedule.generate_timesteps(images.shape[0], local_rng_state)

            local_rng_state, rngs = local_rng_state.get_random_key()
            noise: jax.Array = jax.random.normal(rngs, shape=images.shape)

            rates = noise_schedule.get_rates(noise_level)
            noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
                images, noise, rates)

            def model_loss(params):
                preds = model.apply(params, *noise_schedule.transform_inputs(noisy_images * c_in, noise_level), label_seq)
                preds = model_output_transform.pred_transform(
                    noisy_images, preds, rates)
                nloss = loss_fn(preds, expected_output)
                nloss *= noise_schedule.get_weights(noise_level)
                return jnp.mean(nloss)

            if train_state.dynamic_scale is not None:
                # DynamicScale averages the gradients across replicas itself
                # when given an axis name.
                grad_fn = train_state.dynamic_scale.value_and_grad(
                    model_loss, axis_name=grad_axis_name
                )
                dynamic_scale, is_fin, loss, grads = grad_fn(train_state.params)
                train_state = train_state.replace(dynamic_scale=dynamic_scale)
            else:
                loss, grads = jax.value_and_grad(model_loss)(train_state.params)
                if distributed_training:
                    grads = jax.lax.pmean(grads, "data")

            new_state = train_state.apply_gradients(grads=grads)

            if train_state.dynamic_scale is not None:
                # If the gradients contain Inf/NaN (is_fin is False), keep the
                # previous optimizer state and params, i.e. skip this step.
                select_fn = functools.partial(jnp.where, is_fin)
                new_state = new_state.replace(
                    opt_state=jax.tree_util.tree_map(
                        select_fn, new_state.opt_state, train_state.opt_state
                    ),
                    params=jax.tree_util.tree_map(
                        select_fn, new_state.params, train_state.params
                    ),
                )

            train_state = new_state.apply_ema(ema_decay)

            if distributed_training:
                loss = jax.lax.pmean(loss, "data")
            return train_state, loss, rng_state

        if distributed_training:
            # State and RNG are replicated; the batch and device index are sharded.
            train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
                                   out_specs=(P(), P(), P()))
        train_step = jax.jit(train_step)

        return train_step

    def _define_vaidation_step(self, sampler_class: Type[DiffusionSampler]):
        """Return ``(generate_sampler, generate_samples)`` used by ``validation_loop``.

        ``generate_sampler(state)`` builds a ``sampler_class`` that uses the
        state's EMA parameters. ``generate_samples(batch, sampler, steps)``
        encodes the batch tokens and samples from step 1000 down to step 0.
        (The misspelled method name is kept because callers rely on it.)
        """
        model = self.model
        encoder = self.encoder
        autoencoder = self.autoencoder

        null_labels_full = encoder([""])
        null_labels_full = null_labels_full.astype(jnp.float16)

        def generate_sampler(state: TrainState):
            return sampler_class(
                model=model,
                params=state.ema_params,
                noise_schedule=self.noise_schedule,
                model_output_transform=self.model_output_transform,
                image_size=self.input_shapes['x'][0],
                null_labels_seq=null_labels_full,
                autoencoder=autoencoder,
            )

        def generate_samples(
            batch,
            sampler: DiffusionSampler,
            diffusion_steps: int,
        ):
            labels_seq = encoder.encode_from_tokens(batch)
            labels_seq = jnp.array(labels_seq, dtype=jnp.float16)
            return sampler.generate_images(
                num_images=len(labels_seq),
                diffusion_steps=diffusion_steps,
                start_step=1000,
                end_step=0,
                priors=None,
                model_conditioning_inputs=(labels_seq,),
            )

        return generate_sampler, generate_samples

    def validation_loop(
        self,
        val_state: SimpleTrainState,
        val_step_fn: Callable,
        val_ds,
        val_steps_per_epoch,
        current_step,
        diffusion_steps=200,
    ):
        """Sample one validation batch from the EMA weights and log it to wandb.

        Sampling or logging failures are printed with a traceback and are not
        raised, so validation never stops training.

        Args:
            val_state: state whose ``ema_params`` are sampled from.
            val_step_fn: ``(generate_sampler, generate_samples)`` pair from
                ``_define_vaidation_step``.
            val_ds: zero-argument callable returning an iterable of batches,
                or a falsy value to skip validation.
            val_steps_per_epoch: unused here; kept for interface compatibility.
            current_step: step number used for the wandb captions and ``step``.
            diffusion_steps: number of sampler steps.
        """
        generate_sampler, generate_samples = val_step_fn

        if not val_ds:
            print("No validation dataset provided; skipping sample generation")
            return

        sampler = generate_sampler(val_state)

        try:
            samples = generate_samples(
                next(iter(val_ds())),
                sampler,
                diffusion_steps,
            )
        except Exception as e:
            print("Error generating validation samples", e)
            traceback.print_exc()
            return

        if not self.wandb:
            return

        # Put each sample on wandb.
        try:
            import numpy as np
            from wandb import Image as wandbImage
            for i, sample in enumerate(np.asarray(samples)):
                # Denormalize (assumes samples lie in [-1, 1]) and convert to uint8.
                sample = (sample + 1) * 127.5
                sample = np.clip(sample, 0, 255).astype(np.uint8)
                self.wandb.log({
                    f"sample_{i}": wandbImage(sample, caption=f"Sample {i} at step {current_step}")
                }, step=current_step)
        except Exception as e:
            print("Error logging images to wandb", e)
            traceback.print_exc()

    def fit(self, data, training_steps_per_epoch, epochs, val_steps_per_epoch=8, sampler_class=None):
        """Train using ``data['local_batch_size']``; ``sampler_class`` drives validation sampling."""
        local_batch_size = data['local_batch_size']
        validation_step_args = {
            "sampler_class": sampler_class,
        }
        super().fit(
            data,
            train_steps_per_epoch=training_steps_per_epoch,
            epochs=epochs,
            train_step_args={"batch_size": local_batch_size},
            val_steps_per_epoch=val_steps_per_epoch,
            validation_step_args=validation_step_args,
        )