flaxdiff 0.1.35.5__py3-none-any.whl → 0.1.36__py3-none-any.whl

This diff shows the changes between two package versions as published to a supported public registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in that registry.
flaxdiff/samplers/ddpm.py CHANGED
@@ -3,9 +3,8 @@ import jax.numpy as jnp
 from .common import DiffusionSampler
 from ..utils import MarkovState, RandomMarkovState
 class DDPMSampler(DiffusionSampler):
-    def take_next_step(self,
-                       current_samples, reconstructed_samples,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
+    def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
+                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         mean = self.noise_schedule.get_posterior_mean(reconstructed_samples, current_samples, current_step)
         variance = self.noise_schedule.get_posterior_variance(steps=current_step)
 
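
Every sampler's take_next_step in this release gains a model_conditioning_inputs tuple and a next_step default of 1. A minimal sketch of the new convention (the subclass here is hypothetical; the unpacking into sample_model mirrors the HeunSampler and RK4Sampler hunks further down):

    class MySecondOrderSampler(DiffusionSampler):
        def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
                           pred_noise, current_step, state, next_step=1):
            # re-query the model at the next step under the same conditioning;
            # the conditioning travels as an opaque tuple and is unpacked here
            x_0, eps, _ = self.sample_model(current_samples, next_step, *model_conditioning_inputs)
            return x_0, state
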
@@ -19,9 +18,8 @@ class DDPMSampler(DiffusionSampler):
         return super().generate_images(num_images=num_images, diffusion_steps=diffusion_steps, start_step=start_step, *args, **kwargs)
 
 class SimpleDDPMSampler(DiffusionSampler):
-    def take_next_step(self,
-                       current_samples, reconstructed_samples,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
+    def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
+                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         state, rng = state.get_random_key()
         noise = jax.random.normal(rng, reconstructed_samples.shape, dtype=jnp.float32)
 
@@ -33,11 +31,7 @@ class SimpleDDPMSampler(DiffusionSampler):
 
         noise_ratio_squared = (next_noise_rate ** 2) / (current_noise_rate ** 2)
         signal_ratio_squared = (current_signal_rate ** 2) / (next_signal_rate ** 2)
-        betas = (1 - signal_ratio_squared)
-        gamma = jnp.sqrt(noise_ratio_squared * betas)
+        gamma = jnp.sqrt(noise_ratio_squared * (1 - signal_ratio_squared))
 
         next_samples = next_signal_rate * reconstructed_samples + pred_noise_coeff * pred_noise + noise * gamma
-        # pred_noise_coeff = ((next_noise_rate ** 2) * current_signal_rate) / (current_noise_rate * next_signal_rate)
-        # next_samples = (2 - jnp.sqrt(1 - betas)) * current_samples - betas * (pred_noise / current_noise_rate) + noise * gamma#jnp.sqrt(betas)
-        # next_samples = (1 / (jnp.sqrt(1 - betas) + 1.e-24)) * (current_samples - betas * (pred_noise / current_noise_rate)) + noise * gamma
         return next_samples, state
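
The gamma change above only folds the betas intermediate away; the old and new expressions are algebraically identical. A quick sanity check with made-up rates (the rate values below are purely illustrative):

    import jax.numpy as jnp

    current_signal_rate, current_noise_rate = 0.6, 0.8
    next_signal_rate, next_noise_rate = 0.8, 0.6

    noise_ratio_squared = (next_noise_rate ** 2) / (current_noise_rate ** 2)
    signal_ratio_squared = (current_signal_rate ** 2) / (next_signal_rate ** 2)

    betas = 1 - signal_ratio_squared                            # 0.1.35.5 intermediate
    old_gamma = jnp.sqrt(noise_ratio_squared * betas)           # 0.1.35.5 form
    new_gamma = jnp.sqrt(noise_ratio_squared * (1 - signal_ratio_squared))  # 0.1.36 form
    assert jnp.allclose(old_gamma, new_gamma)
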
@@ -5,9 +5,8 @@ from ..utils import RandomMarkovState
 
 class EulerSampler(DiffusionSampler):
     # Basically a DDIM Sampler but parameterized as an ODE
-    def take_next_step(self,
-                       current_samples, reconstructed_samples,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
+    def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
+                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
         next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
 
@@ -22,9 +21,8 @@ class SimplifiedEulerSampler(DiffusionSampler):
     """
     This is for networks with forward diffusion of the form x_{t+1} = x_t + sigma_t * epsilon_t
     """
-    def take_next_step(self,
-                       current_samples, reconstructed_samples,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
+    def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
+                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         _, current_sigma = self.noise_schedule.get_rates(current_step)
         _, next_sigma = self.noise_schedule.get_rates(next_step)
 
@@ -38,9 +36,8 @@ class EulerAncestralSampler(DiffusionSampler):
     """
     Similar to EulerSampler but with ancestral sampling
     """
-    def take_next_step(self,
-                       current_samples, reconstructed_samples,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
+    def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
+                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
         next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
 
@@ -56,4 +53,4 @@ class EulerAncestralSampler(DiffusionSampler):
         dW = jax.random.normal(subkey, current_samples.shape) * sigma_up
 
         next_samples = current_samples + dx * dt + dW
-        return next_samples, state
+        return next_samples, state
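
The ancestral update visible above (current_samples + dx * dt + dW) is a standard Euler-Maruyama step: a deterministic Euler move plus a Gaussian increment scaled by sigma_up. Written abstractly for a generic drift f and noise scale (not the flaxdiff API):

    import jax

    def euler_maruyama_step(f, x, t, dt, sigma_up, key):
        dx = f(x, t)                                   # deterministic drift, as in plain Euler
        dW = jax.random.normal(key, x.shape) * sigma_up  # stochastic increment, the dW term above
        return x + dx * dt + dW
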
@@ -4,9 +4,8 @@ from .common import DiffusionSampler
 from ..utils import RandomMarkovState
 
 class HeunSampler(DiffusionSampler):
-    def take_next_step(self,
-                       current_samples, reconstructed_samples,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
+    def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
+                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         # Get the noise and signal rates for the current and next steps
         current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
         next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
@@ -18,7 +17,7 @@ class HeunSampler(DiffusionSampler):
         next_samples_0 = current_samples + dx_0 * dt
 
         # Recompute x_0 and eps at the first estimate to refine the derivative
-        estimated_x_0, _, _ = self.sample_model(next_samples_0, next_step)
+        estimated_x_0, _, _ = self.sample_model(next_samples_0, next_step, *model_conditioning_inputs)
 
         # Estimate the refined derivative using the midpoint (Heun's method)
         dx_1 = (next_samples_0 - x_0_coeff * estimated_x_0) / next_sigma
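
For context, the HeunSampler re-evaluates the model at the Euler-predicted point and averages the two slopes, which is textbook Heun (a trapezoidal predictor-corrector). Schematically, for a generic ODE dx/dt = f(x, t):

    def heun_step(f, x, t, dt):
        k1 = f(x, t)                      # slope at the current point
        x_euler = x + k1 * dt             # Euler predictor
        k2 = f(x_euler, t + dt)           # slope at the predicted point
        return x + 0.5 * (k1 + k2) * dt   # trapezoidal corrector
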
@@ -8,9 +8,8 @@ class MultiStepDPM(DiffusionSampler):
         super().__init__(*args, **kwargs)
         self.history = []
 
-    def _renoise(self,
-                 current_samples, reconstructed_samples,
-                 pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
+    def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
+                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         # Get the noise and signal rates for the current and next steps
         current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
         next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
@@ -1,7 +1,7 @@
 import jax
 import jax.numpy as jnp
 from .common import DiffusionSampler
-from ..utils import RandomMarkovState
+from ..utils import RandomMarkovState, MarkovState
 from ..schedulers import GeneralizedNoiseScheduler
 
 class RK4Sampler(DiffusionSampler):
@@ -9,14 +9,14 @@ class RK4Sampler(DiffusionSampler):
         super().__init__(*args, **kwargs)
         assert issubclass(type(self.noise_schedule), GeneralizedNoiseScheduler), "Noise schedule must be a GeneralizedNoiseScheduler"
         @jax.jit
-        def get_derivative(x_t, sigma, state:RandomMarkovState) -> tuple[jnp.ndarray, RandomMarkovState]:
+        def get_derivative(x_t, sigma, state:RandomMarkovState, model_conditioning_inputs) -> tuple[jnp.ndarray, RandomMarkovState]:
             t = self.noise_schedule.get_timesteps(sigma)
-            x_0, eps, _ = self.sample_model(x_t, t)
+            x_0, eps, _ = self.sample_model(x_t, t, *model_conditioning_inputs)
             return eps, state
 
         self.get_derivative = get_derivative
 
-    def sample_step(self, current_samples:jnp.ndarray, current_step, next_step, state:RandomMarkovState=None) -> tuple[jnp.ndarray, RandomMarkovState]:
+    def sample_step(self, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
         step_ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32)
         current_step = step_ones * current_step
         next_step = step_ones * next_step
@@ -25,10 +25,10 @@ class RK4Sampler(DiffusionSampler):
 
         dt = next_sigma - current_sigma
 
-        k1, state = self.get_derivative(current_samples, current_sigma, state)
-        k2, state = self.get_derivative(current_samples + 0.5 * k1 * dt, current_sigma + 0.5 * dt, state)
-        k3, state = self.get_derivative(current_samples + 0.5 * k2 * dt, current_sigma + 0.5 * dt, state)
-        k4, state = self.get_derivative(current_samples + k3 * dt, current_sigma + dt, state)
+        k1, state = self.get_derivative(current_samples, current_sigma, state, model_conditioning_inputs)
+        k2, state = self.get_derivative(current_samples + 0.5 * k1 * dt, current_sigma + 0.5 * dt, state, model_conditioning_inputs)
+        k3, state = self.get_derivative(current_samples + 0.5 * k2 * dt, current_sigma + 0.5 * dt, state, model_conditioning_inputs)
+        k4, state = self.get_derivative(current_samples + k3 * dt, current_sigma + dt, state, model_conditioning_inputs)
 
-        next_samples = current_samples + ((k1 + 2 * k2 + 2 * k3 + k4) / 6) * dt
+        next_samples = current_samples + (((k1 + 2 * k2 + 2 * k3 + k4) * dt) / 6)
         return next_samples, state
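
The last line of this hunk only reorders the arithmetic; both forms compute the same classic fourth-order Runge-Kutta combination. For reference, the scheme the k1..k4 stages implement, written for a generic derivative f over the sigma variable:

    def rk4_step(f, x, sigma, dt):
        k1 = f(x, sigma)
        k2 = f(x + 0.5 * k1 * dt, sigma + 0.5 * dt)
        k3 = f(x + 0.5 * k2 * dt, sigma + 0.5 * dt)
        k4 = f(x + k3 * dt, sigma + dt)
        return x + (k1 + 2 * k2 + 2 * k3 + k4) * dt / 6
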
@@ -14,7 +14,7 @@ from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
 from flaxdiff.utils import RandomMarkovState
 
 from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
-
+from .diffusion_trainer import TrainState
 from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
 
 class AutoEncoderTrainer(SimpleTrainer):
@@ -1,22 +1,27 @@
+import flax
 from flax import linen as nn
 import jax
 from typing import Callable
 from dataclasses import field
 import jax.numpy as jnp
+import traceback
 import optax
+import functools
 from jax.sharding import Mesh, PartitionSpec as P
 from jax.experimental.shard_map import shard_map
-from typing import Dict, Callable, Sequence, Any, Union, Tuple
+from typing import Dict, Callable, Sequence, Any, Union, Tuple, Type
 
 from ..schedulers import NoiseScheduler
 from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
+from ..samplers.common import DiffusionSampler
 
 from flaxdiff.utils import RandomMarkovState
 
 from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
 
 from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
-from flax.training.dynamic_scale import DynamicScale
+from flax.training import dynamic_scale as dynamic_scale_lib
+from flaxdiff.utils import TextEncoder, ConditioningEncoder
 
 class TrainState(SimpleTrainState):
     rngs: jax.random.PRNGKey
@@ -47,6 +52,7 @@ class DiffusionTrainer(SimpleTrainer):
         name: str = "Diffusion",
         model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
         autoencoder: AutoEncoder = None,
+        encoder: ConditioningEncoder = None,
         **kwargs
     ):
         super().__init__(
@@ -62,6 +68,7 @@ class DiffusionTrainer(SimpleTrainer):
         self.unconditional_prob = unconditional_prob
 
         self.autoencoder = autoencoder
+        self.encoder = encoder
 
     def generate_states(
         self,
@@ -84,8 +91,7 @@ class DiffusionTrainer(SimpleTrainer):
             new_state = existing_state
 
         if param_transforms is not None:
-            new_state['params'] = param_transforms(new_state['params'])
-            new_state['ema_params'] = param_transforms(new_state['ema_params'])
+            params = param_transforms(params)
 
         state = TrainState.create(
             apply_fn=model.apply,
@@ -94,7 +100,7 @@ class DiffusionTrainer(SimpleTrainer):
             tx=optimizer,
             rngs=rngs,
             metrics=Metrics.empty(),
-            dynamic_scale = DynamicScale() if use_dynamic_scale else None
+            dynamic_scale = dynamic_scale_lib.DynamicScale() if use_dynamic_scale else None
         )
 
         if existing_best_state is not None:
@@ -105,7 +111,7 @@ class DiffusionTrainer(SimpleTrainer):
 
         return state, best_state
 
-    def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
+    def _define_train_step(self, batch_size):
         noise_schedule: NoiseScheduler = self.noise_schedule
         model = self.model
         model_output_transform = self.model_output_transform
@@ -114,6 +120,11 @@ class DiffusionTrainer(SimpleTrainer):
 
         # Determine the number of unconditional samples
         num_unconditional = int(batch_size * unconditional_prob)
+
+        null_labels_full = self.encoder([""])
+        null_labels_seq = jnp.array(null_labels_full[0], dtype=jnp.float16)
+
+        conditioning_encoder = self.encoder
 
         nS, nC = null_labels_seq.shape
         null_labels_seq = jnp.broadcast_to(
@@ -131,6 +142,11 @@ class DiffusionTrainer(SimpleTrainer):
             local_rng_state = RandomMarkovState(subkey)
 
             images = batch['image']
+
+            # First get the standard deviation of the images
+            # std = jnp.std(images, axis=(1, 2, 3))
+            # is_non_zero = (std > 0)
+
             images = jnp.array(images, dtype=jnp.float32)
             # normalize image
             images = (images - 127.5) / 127.5
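
The normalization here maps uint8 pixels from [0, 255] into [-1, 1]; the wandb logging added further down inverts it with (sample + 1) * 127.5. As a spot check:

    assert (0 - 127.5) / 127.5 == -1.0
    assert (255 - 127.5) / 127.5 == 1.0
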
@@ -140,9 +156,7 @@ class DiffusionTrainer(SimpleTrainer):
                 local_rng_state, rngs = local_rng_state.get_random_key()
                 images = autoencoder.encode(images, rngs)
 
-            output = text_embedder(
-                input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
-            label_seq = output.last_hidden_state
+            label_seq = conditioning_encoder.encode_from_tokens(batch)
 
             # Generate random probabilities to decide how much of this batch will be unconditional
 
@@ -163,8 +177,11 @@ class DiffusionTrainer(SimpleTrainer):
                 preds = model_output_transform.pred_transform(
                     noisy_images, preds, rates)
                 nloss = loss_fn(preds, expected_output)
-                # nloss = jnp.mean(nloss, axis=1)
+                # Ignore the loss contribution of images with zero standard deviation
                 nloss *= noise_schedule.get_weights(noise_level)
+                # nloss = jnp.mean(nloss, axis=(1,2,3))
+                # nloss = jnp.where(is_non_zero, nloss, 0)
+                # nloss = jnp.mean(nloss, where=nloss != 0)
                 nloss = jnp.mean(nloss)
                 loss = nloss
                 return loss
@@ -185,11 +202,11 @@ class DiffusionTrainer(SimpleTrainer):
 
             new_state = train_state.apply_gradients(grads=grads)
 
-            if train_state.dynamic_scale:
+            if train_state.dynamic_scale is not None:
                 # if is_fin == False the gradients contain Inf/NaNs and optimizer state and
                 # params should be restored (= skip this step).
                 select_fn = functools.partial(jnp.where, is_fin)
-                new_state = train_state.replace(
+                new_state = new_state.replace(
                     opt_state=jax.tree_util.tree_map(
                         select_fn, new_state.opt_state, train_state.opt_state
                     ),
@@ -211,24 +228,99 @@ class DiffusionTrainer(SimpleTrainer):
 
         return train_step
 
-    def _define_compute_metrics(self):
-        @jax.jit
-        def compute_metrics(state: TrainState, expected, pred):
-            loss = jnp.mean(jnp.square(pred - expected))
-            metric_updates = state.metrics.single_from_model_output(loss=loss)
-            metrics = state.metrics.merge(metric_updates)
-            state = state.replace(metrics=metrics)
-            return state
-        return compute_metrics
-
-    def fit(self, data, steps_per_epoch, epochs):
-        null_labels_full = data['null_labels_full']
+    def _define_vaidation_step(self, sampler_class: Type[DiffusionSampler]):
+        model = self.model
+        encoder = self.encoder
+        autoencoder = self.autoencoder
+
+        null_labels_full = encoder([""])
+        null_labels_full = null_labels_full.astype(jnp.float16)
+        # null_labels_seq = jnp.array(null_labels_full[0], dtype=jnp.float16)
+
+        def generate_sampler(state: TrainState):
+            sampler = sampler_class(
+                model=model,
+                params=state.ema_params,
+                noise_schedule=self.noise_schedule,
+                model_output_transform=self.model_output_transform,
+                image_size=self.input_shapes['x'][0],
+                null_labels_seq=null_labels_full,
+                autoencoder=autoencoder,
+            )
+            return sampler
+
+        def generate_samples(
+            batch,
+            sampler: DiffusionSampler,
+            diffusion_steps: int,
+        ):
+            labels_seq = encoder.encode_from_tokens(batch)
+            labels_seq = jnp.array(labels_seq, dtype=jnp.float16)
+            samples = sampler.generate_images(
+                num_images=len(labels_seq),
+                diffusion_steps=diffusion_steps,
+                start_step=1000,
+                end_step=0,
+                priors=None,
+                model_conditioning_inputs=(labels_seq,),
+            )
+            return samples
+
+        return generate_sampler, generate_samples
+
+    def validation_loop(
+        self,
+        val_state: SimpleTrainState,
+        val_step_fn: Callable,
+        val_ds,
+        val_steps_per_epoch,
+        current_step,
+        diffusion_steps=200,
+    ):
+        generate_sampler, generate_samples = val_step_fn
+
+        sampler = generate_sampler(val_state)
+
+        val_ds = iter(val_ds()) if val_ds else None
+        # Evaluation step
+        try:
+            samples = generate_samples(
+                next(val_ds),
+                sampler,
+                diffusion_steps,
+            )
+
+            # Put each sample on wandb
+            if self.wandb:
+                import numpy as np
+                from wandb import Image as wandbImage
+                wandb_images = []
+                for i in range(samples.shape[0]):
+                    # convert the sample to numpy
+                    sample = np.array(samples[i])
+                    # denormalize the image
+                    sample = (sample + 1) * 127.5
+                    sample = np.clip(sample, 0, 255).astype(np.uint8)
+                    # add the image to the list
+                    wandb_images.append(sample)
+                    # log the images to wandb
+                    self.wandb.log({
+                        f"sample_{i}": wandbImage(sample, caption=f"Sample {i} at step {current_step}")
+                    }, step=current_step)
+        except Exception as e:
+            print("Error logging images to wandb", e)
+            traceback.print_exc()
+
+    def fit(self, data, training_steps_per_epoch, epochs, val_steps_per_epoch=8, sampler_class=None):
         local_batch_size = data['local_batch_size']
-        text_embedder = data['model']
-        super().fit(data, steps_per_epoch, epochs, {
-            "batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
-
-def boolean_string(s):
-    if type(s) == bool:
-        return s
-    return s == 'True'
+        validation_step_args = {
+            "sampler_class": sampler_class,
+        }
+        super().fit(
+            data,
+            train_steps_per_epoch=training_steps_per_epoch,
+            epochs=epochs,
+            train_step_args={"batch_size": local_batch_size},
+            val_steps_per_epoch=val_steps_per_epoch,
+            validation_step_args=validation_step_args,
+        )
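
Putting the new trainer API together, a hypothetical training call against 0.1.36 might look like the sketch below. Every concrete name (the unet, optimizer, noise_schedule, conditioning_encoder, and data objects, and the import paths) is illustrative and assumed from the hunks above, not taken verbatim from the package:

    from flaxdiff.trainer.diffusion_trainer import DiffusionTrainer  # path assumed
    from flaxdiff.samplers.ddpm import SimpleDDPMSampler

    trainer = DiffusionTrainer(
        model=unet,                      # hypothetical flax module
        optimizer=optimizer,             # hypothetical optax optimizer
        noise_schedule=noise_schedule,
        encoder=conditioning_encoder,    # new in 0.1.36: a ConditioningEncoder
    )
    trainer.fit(
        data,                            # dict providing 'local_batch_size' and the dataloaders
        training_steps_per_epoch=1000,
        epochs=10,
        val_steps_per_epoch=8,           # new: validation sampling now runs inside fit
        sampler_class=SimpleDDPMSampler, # new: any DiffusionSampler subclass
    )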