flaxdiff 0.1.36.4__tar.gz → 0.1.37__tar.gz

This diff shows the changes between two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
Files changed (52)
  1. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/PKG-INFO +60 -22
  2. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/README.md +59 -21
  3. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/samplers/common.py +20 -13
  4. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/samplers/ddim.py +1 -1
  5. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/samplers/ddpm.py +2 -2
  6. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/samplers/euler.py +3 -3
  7. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/samplers/heun_sampler.py +2 -2
  8. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/samplers/multistep_dpm.py +1 -1
  9. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/samplers/rk4_sampler.py +7 -7
  10. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/schedulers/__init__.py +1 -1
  11. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/schedulers/cosine.py +1 -1
  12. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/trainer/diffusion_trainer.py +16 -14
  13. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff.egg-info/PKG-INFO +60 -22
  14. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/pyproject.toml +1 -1
  15. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/__init__.py +0 -0
  16. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/data/__init__.py +0 -0
  17. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/data/dataset_map.py +0 -0
  18. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/data/datasets.py +0 -0
  19. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/data/online_loader.py +0 -0
  20. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/data/sources/gcs.py +0 -0
  21. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/data/sources/tfds.py +0 -0
  22. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/metrics/inception.py +0 -0
  23. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/metrics/utils.py +0 -0
  24. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/models/__init__.py +0 -0
  25. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/models/attention.py +0 -0
  26. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/models/autoencoder/__init__.py +0 -0
  27. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  28. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/models/autoencoder/diffusers.py +0 -0
  29. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  30. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/models/common.py +0 -0
  31. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/models/favor_fastattn.py +0 -0
  32. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/models/simple_unet.py +0 -0
  33. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/models/simple_vit.py +0 -0
  34. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/predictors/__init__.py +0 -0
  35. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/samplers/__init__.py +0 -0
  36. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/schedulers/common.py +0 -0
  37. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/schedulers/continuous.py +0 -0
  38. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/schedulers/discrete.py +0 -0
  39. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/schedulers/exp.py +0 -0
  40. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/schedulers/karras.py +0 -0
  41. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/schedulers/linear.py +0 -0
  42. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/schedulers/sqrt.py +0 -0
  43. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/trainer/__init__.py +0 -0
  44. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
  45. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/trainer/simple_trainer.py +0 -0
  46. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/trainer/video_diffusion_trainer.py +0 -0
  47. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/utils.py +0 -0
  48. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff.egg-info/SOURCES.txt +0 -0
  49. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff.egg-info/dependency_links.txt +0 -0
  50. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff.egg-info/requires.txt +0 -0
  51. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff.egg-info/top_level.txt +0 -0
  52. {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/setup.cfg +0 -0
PKG-INFO +60 -22
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: flaxdiff
- Version: 0.1.36.4
+ Version: 0.1.37
  Summary: A versatile and easy to understand Diffusion library
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
  License-Expression: MIT
@@ -96,7 +96,7 @@ Also, few of the text may be generated with help of github copilot, so please ex
  ### Schedulers
  Implemented in `flaxdiff.schedulers`:
  - **LinearNoiseSchedule** (`flaxdiff.schedulers.LinearNoiseSchedule`): A beta-parameterized discrete scheduler.
- - **CosineNoiseSchedule** (`flaxdiff.schedulers.CosineNoiseSchedule`): A beta-parameterized discrete scheduler.
+ - **CosineNoiseScheduler** (`flaxdiff.schedulers.CosineNoiseScheduler`): A beta-parameterized discrete scheduler.
  - **ExpNoiseSchedule** (`flaxdiff.schedulers.ExpNoiseSchedule`): A beta-parameterized discrete scheduler.
  - **CosineContinuousNoiseScheduler** (`flaxdiff.schedulers.CosineContinuousNoiseScheduler`): A continuous scheduler.
  - **CosineGeneralNoiseScheduler** (`flaxdiff.schedulers.CosineGeneralNoiseScheduler`): A continuous sigma parameterized cosine scheduler.
@@ -147,43 +147,81 @@ sticking to the versions mentioned in the requirements.txt
  Here is a simplified example to get you started with training a diffusion model using FlaxDiff:

  ```python
- from flaxdiff.schedulers import EDMNoiseScheduler
+ from flaxdiff.schedulers import EDMNoiseScheduler, KarrasVENoiseScheduler
  from flaxdiff.predictors import KarrasPredictionTransform
- from flaxdiff.models.simple_unet import SimpleUNet as UNet
+ from flaxdiff.models.simple_unet import Unet
  from flaxdiff.trainer import DiffusionTrainer
+ from flaxdiff.data.datasets import get_dataset_grain
+ from flaxdiff.utils import defaultTextEncodeModel
+ from flaxdiff.samplers.euler import EulerAncestralSampler
  import jax
+ import jax.numpy as jnp
  import optax
  from datetime import datetime

  BATCH_SIZE = 16
- IMAGE_SIZE = 64
+ IMAGE_SIZE = 128

  # Define noise scheduler
  edm_schedule = EDMNoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
-
+ karas_ve_schedule = KarrasVENoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
  # Define model
- unet = UNet(emb_features=256,
- feature_depths=[64, 128, 256, 512],
- attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}],
+ unet = Unet(emb_features=256,
+ feature_depths=[64, 64, 128, 256, 512],
+ attention_configs=[
+ None,
+ {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
+ {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
+ {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
+ {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":False, "use_self_and_cross":False}
+ ],
  num_res_blocks=2,
- num_middle_res_blocks=1)
-
+ num_middle_res_blocks=1
+ )
  # Load dataset
- data, datalen = get_dataset("oxford_flowers102", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE)
+ data = get_dataset_grain("oxford_flowers102", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE)
+ datalen = data['train_len']
  batches = datalen // BATCH_SIZE

+ input_shapes = {
+ "x": (IMAGE_SIZE, IMAGE_SIZE, 3),
+ "temb": (),
+ "textcontext": (77, 768)
+ }
+ text_encoder = defaultTextEncodeModel()
+
+ # Construct a validation set by the prompts
+ val_prompts = ['water tulip', ' a water lily', ' a water lily', ' a photo of a rose', ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold', ' a photo of a marigold']
+
+ def get_val_dataset(batch_size=8):
+ for i in range(0, len(val_prompts), batch_size):
+ prompts = val_prompts[i:i + batch_size]
+ tokens = text_encoder.tokenize(prompts)
+ yield tokens
+
+ data['test'] = get_val_dataset
+ data['test_len'] = len(val_prompts)
+
  # Define optimizer
  solver = optax.adam(2e-4)

  # Create trainer
- trainer = DiffusionTrainer(unet, optimizer=solver,
- noise_schedule=edm_schedule,
- rngs=jax.random.PRNGKey(4),
- name="Diffusion_SDE_VE_" + datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
- model_output_transform=KarrasPredictionTransform(sigma_data=edm_schedule.sigma_data))
+ trainer = DiffusionTrainer(
+ unet, optimizer=solver,
+ input_shapes=input_shapes,
+ noise_schedule=edm_schedule,
+ rngs=jax.random.PRNGKey(4),
+ name="Diffusion_SDE_VE_" + datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
+ model_output_transform=KarrasPredictionTransform(sigma_data=edm_schedule.sigma_data),
+ encoder=text_encoder,
+ distributed_training=True,
+ wandb_config = {
+ "project": 'mlops-msml605-project',
+ "name": f"prototype-{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}",
+ })

  # Train the model
- final_state = trainer.fit(data, batches, epochs=2000)
+ final_state = trainer.fit(data, batches, epochs=2000, sampler_class=EulerAncestralSampler, sampling_noise_schedule=karas_ve_schedule)
  ```

  ### Inference Example
@@ -301,8 +339,8 @@ Images generated by the following prompts using classifier free guidance with gu
  `Training Epochs: 1000`
  `Steps per epoch: 511`

- `Training Noise Schedule: CosineNoiseSchedule`
- `Inference Noise Schedule: CosineNoiseSchedule`
+ `Training Noise Schedule: CosineNoiseScheduler`
+ `Inference Noise Schedule: CosineNoiseScheduler`

  `Model: UNet(emb_features=256,
  feature_depths=[64, 128, 256, 512],
@@ -321,8 +359,8 @@ Images generated by the following prompts using classifier free guidance with gu
  `Training Epochs: 1000`
  `Steps per epoch: 511`

- `Training Noise Schedule: CosineNoiseSchedule`
- `Inference Noise Schedule: CosineNoiseSchedule`
+ `Training Noise Schedule: CosineNoiseScheduler`
+ `Inference Noise Schedule: CosineNoiseScheduler`

  `Model: UNet(emb_features=256,
  feature_depths=[64, 128, 256, 512],
README.md +59 -21
@@ -74,7 +74,7 @@ Also, few of the text may be generated with help of github copilot, so please ex
  ### Schedulers
  Implemented in `flaxdiff.schedulers`:
  - **LinearNoiseSchedule** (`flaxdiff.schedulers.LinearNoiseSchedule`): A beta-parameterized discrete scheduler.
- - **CosineNoiseSchedule** (`flaxdiff.schedulers.CosineNoiseSchedule`): A beta-parameterized discrete scheduler.
+ - **CosineNoiseScheduler** (`flaxdiff.schedulers.CosineNoiseScheduler`): A beta-parameterized discrete scheduler.
  - **ExpNoiseSchedule** (`flaxdiff.schedulers.ExpNoiseSchedule`): A beta-parameterized discrete scheduler.
  - **CosineContinuousNoiseScheduler** (`flaxdiff.schedulers.CosineContinuousNoiseScheduler`): A continuous scheduler.
  - **CosineGeneralNoiseScheduler** (`flaxdiff.schedulers.CosineGeneralNoiseScheduler`): A continuous sigma parameterized cosine scheduler.
@@ -125,43 +125,81 @@ sticking to the versions mentioned in the requirements.txt
  Here is a simplified example to get you started with training a diffusion model using FlaxDiff:

  ```python
- from flaxdiff.schedulers import EDMNoiseScheduler
+ from flaxdiff.schedulers import EDMNoiseScheduler, KarrasVENoiseScheduler
  from flaxdiff.predictors import KarrasPredictionTransform
- from flaxdiff.models.simple_unet import SimpleUNet as UNet
+ from flaxdiff.models.simple_unet import Unet
  from flaxdiff.trainer import DiffusionTrainer
+ from flaxdiff.data.datasets import get_dataset_grain
+ from flaxdiff.utils import defaultTextEncodeModel
+ from flaxdiff.samplers.euler import EulerAncestralSampler
  import jax
+ import jax.numpy as jnp
  import optax
  from datetime import datetime

  BATCH_SIZE = 16
- IMAGE_SIZE = 64
+ IMAGE_SIZE = 128

  # Define noise scheduler
  edm_schedule = EDMNoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
-
+ karas_ve_schedule = KarrasVENoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
  # Define model
- unet = UNet(emb_features=256,
- feature_depths=[64, 128, 256, 512],
- attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}],
+ unet = Unet(emb_features=256,
+ feature_depths=[64, 64, 128, 256, 512],
+ attention_configs=[
+ None,
+ {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
+ {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
+ {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
+ {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":False, "use_self_and_cross":False}
+ ],
  num_res_blocks=2,
- num_middle_res_blocks=1)
-
+ num_middle_res_blocks=1
+ )
  # Load dataset
- data, datalen = get_dataset("oxford_flowers102", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE)
+ data = get_dataset_grain("oxford_flowers102", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE)
+ datalen = data['train_len']
  batches = datalen // BATCH_SIZE

+ input_shapes = {
+ "x": (IMAGE_SIZE, IMAGE_SIZE, 3),
+ "temb": (),
+ "textcontext": (77, 768)
+ }
+ text_encoder = defaultTextEncodeModel()
+
+ # Construct a validation set by the prompts
+ val_prompts = ['water tulip', ' a water lily', ' a water lily', ' a photo of a rose', ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold', ' a photo of a marigold']
+
+ def get_val_dataset(batch_size=8):
+ for i in range(0, len(val_prompts), batch_size):
+ prompts = val_prompts[i:i + batch_size]
+ tokens = text_encoder.tokenize(prompts)
+ yield tokens
+
+ data['test'] = get_val_dataset
+ data['test_len'] = len(val_prompts)
+
  # Define optimizer
  solver = optax.adam(2e-4)

  # Create trainer
- trainer = DiffusionTrainer(unet, optimizer=solver,
- noise_schedule=edm_schedule,
- rngs=jax.random.PRNGKey(4),
- name="Diffusion_SDE_VE_" + datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
- model_output_transform=KarrasPredictionTransform(sigma_data=edm_schedule.sigma_data))
+ trainer = DiffusionTrainer(
+ unet, optimizer=solver,
+ input_shapes=input_shapes,
+ noise_schedule=edm_schedule,
+ rngs=jax.random.PRNGKey(4),
+ name="Diffusion_SDE_VE_" + datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
+ model_output_transform=KarrasPredictionTransform(sigma_data=edm_schedule.sigma_data),
+ encoder=text_encoder,
+ distributed_training=True,
+ wandb_config = {
+ "project": 'mlops-msml605-project',
+ "name": f"prototype-{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}",
+ })

  # Train the model
- final_state = trainer.fit(data, batches, epochs=2000)
+ final_state = trainer.fit(data, batches, epochs=2000, sampler_class=EulerAncestralSampler, sampling_noise_schedule=karas_ve_schedule)
  ```

  ### Inference Example
@@ -279,8 +317,8 @@ Images generated by the following prompts using classifier free guidance with gu
  `Training Epochs: 1000`
  `Steps per epoch: 511`

- `Training Noise Schedule: CosineNoiseSchedule`
- `Inference Noise Schedule: CosineNoiseSchedule`
+ `Training Noise Schedule: CosineNoiseScheduler`
+ `Inference Noise Schedule: CosineNoiseScheduler`

  `Model: UNet(emb_features=256,
  feature_depths=[64, 128, 256, 512],
@@ -299,8 +337,8 @@ Images generated by the following prompts using classifier free guidance with gu
  `Training Epochs: 1000`
  `Steps per epoch: 511`

- `Training Noise Schedule: CosineNoiseSchedule`
- `Inference Noise Schedule: CosineNoiseSchedule`
+ `Training Noise Schedule: CosineNoiseScheduler`
+ `Inference Noise Schedule: CosineNoiseScheduler`

  `Model: UNet(emb_features=256,
  feature_depths=[64, 128, 256, 512],
flaxdiff/samplers/common.py +20 -13
@@ -37,7 +37,7 @@ class DiffusionSampler():
  # Classifier free guidance
  assert null_labels_seq is not None, "Null labels sequence is required for classifier-free guidance"
  print("Using classifier-free guidance")
- def sample_model(x_t, t, *additional_inputs):
+ def sample_model(params, x_t, t, *additional_inputs):
  # Concatenate unconditional and conditional inputs
  x_t_cat = jnp.concatenate([x_t] * 2, axis=0)
  t_cat = jnp.concatenate([t] * 2, axis=0)
@@ -46,7 +46,7 @@ class DiffusionSampler():

  text_labels_seq, = additional_inputs
  text_labels_seq = jnp.concatenate([text_labels_seq, jnp.broadcast_to(null_labels_seq, text_labels_seq.shape)], axis=0)
- model_output = self.model.apply(self.params, *self.noise_schedule.transform_inputs(x_t_cat * c_in_cat, t_cat), text_labels_seq)
+ model_output = self.model.apply(params, *self.noise_schedule.transform_inputs(x_t_cat * c_in_cat, t_cat), text_labels_seq)
  # Split model output into unconditional and conditional parts
  model_output_cond, model_output_uncond = jnp.split(model_output, 2, axis=0)
  model_output = model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)
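The refactor only threads `params` through the guided closure explicitly instead of closing over `self.params`; the last kept line above is the standard classifier-free guidance combination, unchanged. A minimal standalone sketch of that combination (toy arrays, not the FlaxDiff API):

```python
import jax.numpy as jnp

def classifier_free_guidance(cond, uncond, guidance_scale):
    # Push the conditional prediction away from the unconditional one;
    # guidance_scale = 1.0 recovers the purely conditional output.
    return uncond + guidance_scale * (cond - uncond)

cond = jnp.full((2, 8, 8, 3), 0.8)    # conditional model output (toy values)
uncond = jnp.full((2, 8, 8, 3), 0.5)  # unconditional model output
guided = classifier_free_guidance(cond, uncond, 3.0)  # every entry: 0.5 + 3 * 0.3 = 1.4
```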
@@ -55,10 +55,10 @@
  return x_0, eps, model_output
  else:
  # Unconditional sampling
- def sample_model(x_t, t, *additional_inputs):
+ def sample_model(params, x_t, t, *additional_inputs):
  rates = self.noise_schedule.get_rates(t)
  c_in = self.model_output_transform.get_input_scale(rates)
- model_output = self.model.apply(self.params, *self.noise_schedule.transform_inputs(x_t * c_in, t), *additional_inputs)
+ model_output = self.model.apply(params, *self.noise_schedule.transform_inputs(x_t * c_in, t), *additional_inputs)
  x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
  return x_0, eps, model_output

@@ -70,22 +70,23 @@
  self.sample_model = sample_model

  # Used to sample from the diffusion model
- def sample_step(self, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
+ def sample_step(self, sample_model_fn, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
  # First clip the noisy images
  step_ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32)
  current_step = step_ones * current_step
  next_step = step_ones * next_step
- pred_images, pred_noise, _ = self.sample_model(current_samples, current_step, *model_conditioning_inputs)
+ pred_images, pred_noise, _ = sample_model_fn(current_samples, current_step, *model_conditioning_inputs)
  # plotImages(pred_images)
  # pred_images = clip_images(pred_images)
  new_samples, state = self.take_next_step(current_samples=current_samples, reconstructed_samples=pred_images,
- pred_noise=pred_noise, current_step=current_step, next_step=next_step, state=state,
- model_conditioning_inputs=model_conditioning_inputs
+ pred_noise=pred_noise, current_step=current_step, next_step=next_step, state=state,
+ model_conditioning_inputs=model_conditioning_inputs,
+ sample_model_fn=sample_model_fn,
  )
  return new_samples, state

  def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
- pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+ pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1,) -> tuple[jnp.ndarray, RandomMarkovState]:
  # estimate the q(x_{t-1} | x_t, x_0).
  # pred_images is x_0, noisy_images is x_t, steps is t
  return NotImplementedError
@@ -114,6 +115,7 @@
  return jax.random.normal(rngs, (num_images, image_size, image_size, image_channels)) * variance

  def generate_images(self,
+ params:dict=None,
  num_images=16,
  diffusion_steps=1000,
  start_step:int = None,
@@ -131,10 +133,15 @@
  if self.autoencoder is not None:
  priors = self.autoencoder.encode(priors)
  samples = priors
+
+ params = params if params is not None else self.params
+
+ def sample_model_fn(x_t, t, *additional_inputs):
+ return self.sample_model(params, x_t, t, *additional_inputs)

  # @jax.jit
- def sample_step(state:RandomMarkovState, samples, current_step, next_step):
- samples, state = self.sample_step(current_samples=samples,
+ def sample_step(sample_model_fn, state:RandomMarkovState, samples, current_step, next_step):
+ samples, state = self.sample_step(sample_model_fn=sample_model_fn, current_samples=samples,
  current_step=current_step,
  model_conditioning_inputs=model_conditioning_inputs,
  state=state, next_step=next_step)
@@ -154,11 +161,11 @@
  next_step = self.scale_steps(steps[i+1] if i+1 < len(steps) else 0)
  if i != len(steps) - 1:
  # print("normal step")
- samples, rngstate = sample_step(rngstate, samples, current_step, next_step)
+ samples, rngstate = sample_step(sample_model_fn, rngstate, samples, current_step, next_step)
  else:
  # print("last step")
  step_ones = jnp.ones((num_images, ), dtype=jnp.int32)
- samples, _, _ = self.sample_model(samples, current_step * step_ones, *model_conditioning_inputs)
+ samples, _, _ = sample_model_fn(samples, current_step * step_ones, *model_conditioning_inputs)
  if self.autoencoder is not None:
  samples = self.autoencoder.decode(samples)
  samples = clip_images(samples)
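With `params` resolved per call (`params = params if params is not None else self.params`), a single sampler instance can serve many checkpoints. A hedged usage sketch, where `sampler`, `train_params`, and `ema_params` are illustrative handles rather than names from this diff:

```python
# Assumes `sampler` is a constructed DiffusionSampler subclass and the two
# pytrees are Flax parameter dicts from different checkpoints.
images_ema = sampler.generate_images(params=ema_params, num_images=4, diffusion_steps=200)
images_raw = sampler.generate_images(params=train_params, num_images=4, diffusion_steps=200)

# Omitting `params` falls back to self.params, preserving the old behavior.
images_default = sampler.generate_images(num_images=4, diffusion_steps=200)
```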
flaxdiff/samplers/ddim.py +1 -1
@@ -4,7 +4,7 @@ from ..utils import MarkovState, RandomMarkovState

  class DDIMSampler(DiffusionSampler):
  def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
- pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+ pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
  next_signal_rate, next_noise_rate = self.noise_schedule.get_rates(next_step)
  return reconstructed_samples * next_signal_rate + pred_noise * next_noise_rate, state

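Only the `take_next_step` signature changes here; the kept return line is the entire deterministic DDIM update, re-noising the predicted clean sample at the next noise level. As a toy sketch:

```python
import jax.numpy as jnp

def ddim_update(x0_pred, eps_pred, next_signal_rate, next_noise_rate):
    # Deterministic DDIM step: recombine predicted x_0 and predicted noise
    # at the next (lower) noise level.
    return x0_pred * next_signal_rate + eps_pred * next_noise_rate

print(ddim_update(jnp.array(1.0), jnp.array(0.2), 0.99, 0.14))  # 0.99 + 0.028 = 1.018
```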
flaxdiff/samplers/ddpm.py +2 -2
@@ -4,7 +4,7 @@ from .common import DiffusionSampler
  from ..utils import MarkovState, RandomMarkovState
  class DDPMSampler(DiffusionSampler):
  def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
- pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+ pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
  mean = self.noise_schedule.get_posterior_mean(reconstructed_samples, current_samples, current_step)
  variance = self.noise_schedule.get_posterior_variance(steps=current_step)

@@ -19,7 +19,7 @@ class DDPMSampler(DiffusionSampler):

  class SimpleDDPMSampler(DiffusionSampler):
  def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
- pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+ pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
  state, rng = state.get_random_key()
  noise = jax.random.normal(rng, reconstructed_samples.shape, dtype=jnp.float32)

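For context, the posterior mean and variance fetched above feed the usual ancestral DDPM step, x_{t-1} = mean + sqrt(variance) * noise; a standalone sketch:

```python
import jax
import jax.numpy as jnp

def ddpm_ancestral_step(mean, variance, rng):
    # Sample x_{t-1} ~ N(posterior mean, posterior variance).
    noise = jax.random.normal(rng, mean.shape, dtype=jnp.float32)
    return mean + jnp.sqrt(variance) * noise

x_prev = ddpm_ancestral_step(jnp.zeros((2, 2)), jnp.asarray(0.01), jax.random.PRNGKey(0))
print(x_prev)  # zero mean plus Gaussian noise with std 0.1
```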
flaxdiff/samplers/euler.py +3 -3
@@ -6,7 +6,7 @@ from ..utils import RandomMarkovState
  class EulerSampler(DiffusionSampler):
  # Basically a DDIM Sampler but parameterized as an ODE
  def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
- pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+ pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
  current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
  next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)

@@ -22,7 +22,7 @@ class SimplifiedEulerSampler(DiffusionSampler):
  This is for networks with forward diffusion of the form x_{t+1} = x_t + sigma_t * epsilon_t
  """
  def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
- pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+ pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
  _, current_sigma = self.noise_schedule.get_rates(current_step)
  _, next_sigma = self.noise_schedule.get_rates(next_step)

@@ -37,7 +37,7 @@ class EulerAncestralSampler(DiffusionSampler):
  Similar to EulerSampler but with ancestral sampling
  """
  def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
- pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+ pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
  current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
  next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)

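All three Euler variants receive the same signature extension. For orientation, the basic Euler move in a VE-style parameterization treats (x_t - x̂₀)/σ as the ODE derivative and steps linearly in σ; this is a hedged sketch of the general idea, not the exact FlaxDiff update:

```python
def euler_step(x_t, x0_pred, sigma, next_sigma):
    # Probability-flow Euler step (VE parameterization, sketch):
    # derivative d = (x_t - x0_pred) / sigma, then a linear step in sigma.
    d = (x_t - x0_pred) / sigma
    return x_t + d * (next_sigma - sigma)

print(euler_step(1.2, 1.0, 10.0, 5.0))  # 1.2 + 0.02 * (-5.0) = 1.1
```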
flaxdiff/samplers/heun_sampler.py +2 -2
@@ -5,7 +5,7 @@ from ..utils import RandomMarkovState

  class HeunSampler(DiffusionSampler):
  def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
- pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+ pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
  # Get the noise and signal rates for the current and next steps
  current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
  next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
@@ -17,7 +17,7 @@ class HeunSampler(DiffusionSampler):
  next_samples_0 = current_samples + dx_0 * dt

  # Recompute x_0 and eps at the first estimate to refine the derivative
- estimated_x_0, _, _ = self.sample_model(next_samples_0, next_step, *model_conditioning_inputs)
+ estimated_x_0, _, _ = sample_model_fn(next_samples_0, next_step, *model_conditioning_inputs)

  # Estimate the refined derivative using the midpoint (Heun's method)
  dx_1 = (next_samples_0 - x_0_coeff * estimated_x_0) / next_sigma
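The corrector now re-queries the model through `sample_model_fn` instead of `self.sample_model`; the numerics remain standard Heun (improved Euler): take an Euler step, re-evaluate the derivative at the prediction, and average the two slopes. A generic scalar-ODE sketch:

```python
def heun_step(f, x, t, dt):
    """One Heun (predictor-corrector) step for dx/dt = f(x, t)."""
    k1 = f(x, t)                      # slope at the current point
    x_pred = x + k1 * dt              # Euler predictor
    k2 = f(x_pred, t + dt)            # slope at the predicted point
    return x + 0.5 * (k1 + k2) * dt   # corrector: average the two slopes

print(heun_step(lambda x, t: -x, 1.0, 0.0, 0.1))  # ~0.905, vs exp(-0.1) ~ 0.9048
```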
flaxdiff/samplers/multistep_dpm.py +1 -1
@@ -9,7 +9,7 @@ class MultiStepDPM(DiffusionSampler):
  self.history = []

  def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
- pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+ pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
  # Get the noise and signal rates for the current and next steps
  current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
  next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
flaxdiff/samplers/rk4_sampler.py +7 -7
@@ -9,14 +9,14 @@ class RK4Sampler(DiffusionSampler):
  super().__init__(*args, **kwargs)
  assert issubclass(type(self.noise_schedule), GeneralizedNoiseScheduler), "Noise schedule must be a GeneralizedNoiseScheduler"
  @jax.jit
- def get_derivative(x_t, sigma, state:RandomMarkovState, model_conditioning_inputs) -> tuple[jnp.ndarray, RandomMarkovState]:
+ def get_derivative(sample_model_fn, x_t, sigma, state:RandomMarkovState, model_conditioning_inputs) -> tuple[jnp.ndarray, RandomMarkovState]:
  t = self.noise_schedule.get_timesteps(sigma)
- x_0, eps, _ = self.sample_model(x_t, t, *model_conditioning_inputs)
+ x_0, eps, _ = sample_model_fn(x_t, t, *model_conditioning_inputs)
  return eps, state

  self.get_derivative = get_derivative

- def sample_step(self, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
+ def sample_step(self, sample_model_fn, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
  step_ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32)
  current_step = step_ones * current_step
  next_step = step_ones * next_step
@@ -25,10 +25,10 @@

  dt = next_sigma - current_sigma

- k1, state = self.get_derivative(current_samples, current_sigma, state, model_conditioning_inputs)
- k2, state = self.get_derivative(current_samples + 0.5 * k1 * dt, current_sigma + 0.5 * dt, state, model_conditioning_inputs)
- k3, state = self.get_derivative(current_samples + 0.5 * k2 * dt, current_sigma + 0.5 * dt, state, model_conditioning_inputs)
- k4, state = self.get_derivative(current_samples + k3 * dt, current_sigma + dt, state, model_conditioning_inputs)
+ k1, state = self.get_derivative(sample_model_fn, current_samples, current_sigma, state, model_conditioning_inputs)
+ k2, state = self.get_derivative(sample_model_fn, current_samples + 0.5 * k1 * dt, current_sigma + 0.5 * dt, state, model_conditioning_inputs)
+ k3, state = self.get_derivative(sample_model_fn, current_samples + 0.5 * k2 * dt, current_sigma + 0.5 * dt, state, model_conditioning_inputs)
+ k4, state = self.get_derivative(sample_model_fn, current_samples + k3 * dt, current_sigma + dt, state, model_conditioning_inputs)

  next_samples = current_samples + (((k1 + 2 * k2 + 2 * k3 + k4) * dt) / 6)
  return next_samples, state
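The four `get_derivative` calls are the classic fourth-order Runge-Kutta stages, now all routed through `sample_model_fn`. The same update on a plain scalar ODE, for reference:

```python
def rk4_step(f, x, t, dt):
    # Classic RK4; matches the (k1 + 2*k2 + 2*k3 + k4) * dt / 6 update above.
    k1 = f(x, t)
    k2 = f(x + 0.5 * k1 * dt, t + 0.5 * dt)
    k3 = f(x + 0.5 * k2 * dt, t + 0.5 * dt)
    k4 = f(x + k3 * dt, t + dt)
    return x + (k1 + 2 * k2 + 2 * k3 + k4) * dt / 6

print(rk4_step(lambda x, t: -x, 1.0, 0.0, 0.1))  # ~0.9048375, matches exp(-0.1) to ~1e-7
```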
flaxdiff/schedulers/__init__.py +1 -1
@@ -1,6 +1,6 @@
  from .discrete import DiscreteNoiseScheduler
  from .common import NoiseScheduler, GeneralizedNoiseScheduler
- from .cosine import CosineNoiseSchedule, ContinuousNoiseScheduler, CosineGeneralNoiseScheduler
+ from .cosine import CosineNoiseScheduler, ContinuousNoiseScheduler, CosineGeneralNoiseScheduler
  from .linear import LinearNoiseSchedule
  from .sqrt import SqrtContinuousNoiseScheduler
  from .karras import KarrasVENoiseScheduler, SimpleExpNoiseScheduler, EDMNoiseScheduler
flaxdiff/schedulers/cosine.py +1 -1
@@ -12,7 +12,7 @@ def cosine_beta_schedule(timesteps, start_angle=0.008, end_angle=0.999):
  betas = 1 - (alphas_bar[1:] / alphas_bar[:-1])
  return np.clip(betas, 0, end_angle)

- class CosineNoiseSchedule(DiscreteNoiseScheduler):
+ class CosineNoiseScheduler(DiscreteNoiseScheduler):
  def __init__(self, timesteps, beta_start=0.008, beta_end=0.999, *args, **kwargs):
  super().__init__(timesteps, beta_start, beta_end, schedule_fn=cosine_beta_schedule, *args, **kwargs)

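The change is a pure rename; the kept lines show the schedule itself deriving betas from ratios of consecutive squared-cosine ᾱ values. A minimal NumPy sketch of that construction, assuming `start_angle`/`end_angle` play the roles of the usual offset and beta clip:

```python
import numpy as np

def cosine_betas(timesteps, s=0.008, max_beta=0.999):
    # alpha_bar(t) = cos^2(((t/T + s) / (1 + s)) * pi/2), then
    # beta_t = 1 - alpha_bar[t] / alpha_bar[t-1], clipped.
    t = np.linspace(0, timesteps, timesteps + 1)
    alphas_bar = np.cos(((t / timesteps + s) / (1 + s)) * np.pi / 2) ** 2
    betas = 1 - alphas_bar[1:] / alphas_bar[:-1]
    return np.clip(betas, 0, max_beta)

print(cosine_betas(1000)[:3])  # betas start near zero and grow toward max_beta
```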
flaxdiff/trainer/diffusion_trainer.py +16 -14
@@ -235,19 +235,19 @@ class DiffusionTrainer(SimpleTrainer):
  null_labels_full = null_labels_full.astype(jnp.float16)
  # null_labels_seq = jnp.array(null_labels_full[0], dtype=jnp.float16)

- def generate_sampler(state: TrainState):
- sampler = sampler_class(
- model=model,
- params=state.ema_params,
- noise_schedule=self.noise_schedule if sampling_noise_schedule is None else sampling_noise_schedule,
- model_output_transform=self.model_output_transform,
- image_size=self.input_shapes['x'][0],
- null_labels_seq=null_labels_full,
- autoencoder=autoencoder,
- )
- return sampler
+ sampler = sampler_class(
+ model=model,
+ params=None,
+ noise_schedule=self.noise_schedule if sampling_noise_schedule is None else sampling_noise_schedule,
+ model_output_transform=self.model_output_transform,
+ image_size=self.input_shapes['x'][0],
+ null_labels_seq=null_labels_full,
+ autoencoder=autoencoder,
+ guidance_scale=3.0,
+ )

  def generate_samples(
+ val_state: TrainState,
  batch,
  sampler: DiffusionSampler,
  diffusion_steps: int,
@@ -255,6 +255,7 @@ class DiffusionTrainer(SimpleTrainer):
  labels_seq = encoder.encode_from_tokens(batch)
  labels_seq = jnp.array(labels_seq, dtype=jnp.float16)
  samples = sampler.generate_images(
+ params=val_state.ema_params,
  num_images=len(labels_seq),
  diffusion_steps=diffusion_steps,
  start_step=1000,
@@ -264,7 +265,7 @@
  )
  return samples

- return generate_sampler, generate_samples
+ return sampler, generate_samples

  def validation_loop(
  self,
@@ -275,14 +276,15 @@
  current_step,
  diffusion_steps=200,
  ):
- generate_sampler, generate_samples = val_step_fn
+ sampler, generate_samples = val_step_fn

- sampler = generate_sampler(val_state)
+ # sampler = generate_sampler(val_state)

  val_ds = iter(val_ds()) if val_ds else None
  # Evaluation step
  try:
  samples = generate_samples(
+ val_state,
  next(val_ds),
  sampler,
  diffusion_steps,
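Net effect of the trainer changes: the sampler is constructed once with `params=None`, and the validation loop passes the current EMA parameters at generation time. A hedged sketch of the resulting call pattern (names illustrative, arguments elided):

```python
# Built once, outside the loop:
# sampler = sampler_class(model=model, params=None, ...)

def run_validation(sampler, val_state, tokens_batch, diffusion_steps=200):
    # Bind the *current* EMA weights per call instead of rebuilding the
    # sampler around a stale parameter copy each validation round.
    return sampler.generate_images(
        params=val_state.ema_params,
        num_images=len(tokens_batch),
        diffusion_steps=diffusion_steps,
    )
```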
flaxdiff.egg-info/PKG-INFO +60 -22
(Identical to the PKG-INFO diff above: the version bump to 0.1.37, the CosineNoiseSchedule → CosineNoiseScheduler renames, and the updated training example.)
pyproject.toml +1 -1
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "flaxdiff"
- version = "0.1.36.4"
+ version = "0.1.37"
  description = "A versatile and easy to understand Diffusion library"
  readme = "README.md"
  authors = [