flaxdiff-0.1.36.4.tar.gz → flaxdiff-0.1.36.5.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/PKG-INFO +60 -22
  2. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/README.md +59 -21
  3. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/schedulers/__init__.py +1 -1
  4. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/schedulers/cosine.py +1 -1
  5. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff.egg-info/PKG-INFO +60 -22
  6. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/pyproject.toml +1 -1
  7. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/__init__.py +0 -0
  8. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/data/__init__.py +0 -0
  9. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/data/dataset_map.py +0 -0
  10. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/data/datasets.py +0 -0
  11. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/data/online_loader.py +0 -0
  12. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/data/sources/gcs.py +0 -0
  13. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/data/sources/tfds.py +0 -0
  14. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/metrics/inception.py +0 -0
  15. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/metrics/utils.py +0 -0
  16. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/models/__init__.py +0 -0
  17. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/models/attention.py +0 -0
  18. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/models/autoencoder/__init__.py +0 -0
  19. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  20. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/models/autoencoder/diffusers.py +0 -0
  21. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  22. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/models/common.py +0 -0
  23. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/models/favor_fastattn.py +0 -0
  24. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/models/simple_unet.py +0 -0
  25. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/models/simple_vit.py +0 -0
  26. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/predictors/__init__.py +0 -0
  27. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/samplers/__init__.py +0 -0
  28. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/samplers/common.py +0 -0
  29. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/samplers/ddim.py +0 -0
  30. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/samplers/ddpm.py +0 -0
  31. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/samplers/euler.py +0 -0
  32. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/samplers/heun_sampler.py +0 -0
  33. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/samplers/multistep_dpm.py +0 -0
  34. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/samplers/rk4_sampler.py +0 -0
  35. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/schedulers/common.py +0 -0
  36. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/schedulers/continuous.py +0 -0
  37. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/schedulers/discrete.py +0 -0
  38. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/schedulers/exp.py +0 -0
  39. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/schedulers/karras.py +0 -0
  40. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/schedulers/linear.py +0 -0
  41. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/schedulers/sqrt.py +0 -0
  42. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/trainer/__init__.py +0 -0
  43. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
  44. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/trainer/diffusion_trainer.py +0 -0
  45. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/trainer/simple_trainer.py +0 -0
  46. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/trainer/video_diffusion_trainer.py +0 -0
  47. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff/utils.py +0 -0
  48. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff.egg-info/SOURCES.txt +0 -0
  49. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff.egg-info/dependency_links.txt +0 -0
  50. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff.egg-info/requires.txt +0 -0
  51. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/flaxdiff.egg-info/top_level.txt +0 -0
  52. {flaxdiff-0.1.36.4 → flaxdiff-0.1.36.5}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.1.36.4
3
+ Version: 0.1.36.5
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
@@ -96,7 +96,7 @@ Also, few of the text may be generated with help of github copilot, so please ex
96
96
  ### Schedulers
97
97
  Implemented in `flaxdiff.schedulers`:
98
98
  - **LinearNoiseSchedule** (`flaxdiff.schedulers.LinearNoiseSchedule`): A beta-parameterized discrete scheduler.
99
- - **CosineNoiseSchedule** (`flaxdiff.schedulers.CosineNoiseSchedule`): A beta-parameterized discrete scheduler.
99
+ - **CosineNoiseScheduler** (`flaxdiff.schedulers.CosineNoiseScheduler`): A beta-parameterized discrete scheduler.
100
100
  - **ExpNoiseSchedule** (`flaxdiff.schedulers.ExpNoiseSchedule`): A beta-parameterized discrete scheduler.
101
101
  - **CosineContinuousNoiseScheduler** (`flaxdiff.schedulers.CosineContinuousNoiseScheduler`): A continuous scheduler.
102
102
  - **CosineGeneralNoiseScheduler** (`flaxdiff.schedulers.CosineGeneralNoiseScheduler`): A continuous sigma parameterized cosine scheduler.
@@ -147,43 +147,81 @@ sticking to the versions mentioned in the requirements.txt
147
147
  Here is a simplified example to get you started with training a diffusion model using FlaxDiff:
148
148
 
149
149
  ```python
150
- from flaxdiff.schedulers import EDMNoiseScheduler
150
+ from flaxdiff.schedulers import EDMNoiseScheduler, KarrasVENoiseScheduler
151
151
  from flaxdiff.predictors import KarrasPredictionTransform
152
- from flaxdiff.models.simple_unet import SimpleUNet as UNet
152
+ from flaxdiff.models.simple_unet import Unet
153
153
  from flaxdiff.trainer import DiffusionTrainer
154
+ from flaxdiff.data.datasets import get_dataset_grain
155
+ from flaxdiff.utils import defaultTextEncodeModel
156
+ from flaxdiff.samplers.euler import EulerAncestralSampler
154
157
  import jax
158
+ import jax.numpy as jnp
155
159
  import optax
156
160
  from datetime import datetime
157
161
 
158
162
  BATCH_SIZE = 16
159
- IMAGE_SIZE = 64
163
+ IMAGE_SIZE = 128
160
164
 
161
165
  # Define noise scheduler
162
166
  edm_schedule = EDMNoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
163
-
167
+ karas_ve_schedule = KarrasVENoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
164
168
  # Define model
165
- unet = UNet(emb_features=256,
166
- feature_depths=[64, 128, 256, 512],
167
- attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}],
169
+ unet = Unet(emb_features=256,
170
+ feature_depths=[64, 64, 128, 256, 512],
171
+ attention_configs=[
172
+ None,
173
+ {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
174
+ {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
175
+ {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
176
+ {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":False, "use_self_and_cross":False}
177
+ ],
168
178
  num_res_blocks=2,
169
- num_middle_res_blocks=1)
170
-
179
+ num_middle_res_blocks=1
180
+ )
171
181
  # Load dataset
172
- data, datalen = get_dataset("oxford_flowers102", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE)
182
+ data = get_dataset_grain("oxford_flowers102", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE)
183
+ datalen = data['train_len']
173
184
  batches = datalen // BATCH_SIZE
174
185
 
186
+ input_shapes = {
187
+ "x": (IMAGE_SIZE, IMAGE_SIZE, 3),
188
+ "temb": (),
189
+ "textcontext": (77, 768)
190
+ }
191
+ text_encoder = defaultTextEncodeModel()
192
+
193
+ # Construct a validation set by the prompts
194
+ val_prompts = ['water tulip', ' a water lily', ' a water lily', ' a photo of a rose', ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold', ' a photo of a marigold']
195
+
196
+ def get_val_dataset(batch_size=8):
197
+ for i in range(0, len(val_prompts), batch_size):
198
+ prompts = val_prompts[i:i + batch_size]
199
+ tokens = text_encoder.tokenize(prompts)
200
+ yield tokens
201
+
202
+ data['test'] = get_val_dataset
203
+ data['test_len'] = len(val_prompts)
204
+
175
205
  # Define optimizer
176
206
  solver = optax.adam(2e-4)
177
207
 
178
208
  # Create trainer
179
- trainer = DiffusionTrainer(unet, optimizer=solver,
180
- noise_schedule=edm_schedule,
181
- rngs=jax.random.PRNGKey(4),
182
- name="Diffusion_SDE_VE_" + datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
183
- model_output_transform=KarrasPredictionTransform(sigma_data=edm_schedule.sigma_data))
209
+ trainer = DiffusionTrainer(
210
+ unet, optimizer=solver,
211
+ input_shapes=input_shapes,
212
+ noise_schedule=edm_schedule,
213
+ rngs=jax.random.PRNGKey(4),
214
+ name="Diffusion_SDE_VE_" + datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
215
+ model_output_transform=KarrasPredictionTransform(sigma_data=edm_schedule.sigma_data),
216
+ encoder=text_encoder,
217
+ distributed_training=True,
218
+ wandb_config = {
219
+ "project": 'mlops-msml605-project',
220
+ "name": f"prototype-{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}",
221
+ })
184
222
 
185
223
  # Train the model
186
- final_state = trainer.fit(data, batches, epochs=2000)
224
+ final_state = trainer.fit(data, batches, epochs=2000, sampler_class=EulerAncestralSampler, sampling_noise_schedule=karas_ve_schedule)
187
225
  ```
188
226
 
189
227
  ### Inference Example
@@ -301,8 +339,8 @@ Images generated by the following prompts using classifier free guidance with gu
301
339
  `Training Epochs: 1000`
302
340
  `Steps per epoch: 511`
303
341
 
304
- `Training Noise Schedule: CosineNoiseSchedule`
305
- `Inference Noise Schedule: CosineNoiseSchedule`
342
+ `Training Noise Schedule: CosineNoiseScheduler`
343
+ `Inference Noise Schedule: CosineNoiseScheduler`
306
344
 
307
345
  `Model: UNet(emb_features=256,
308
346
  feature_depths=[64, 128, 256, 512],
@@ -321,8 +359,8 @@ Images generated by the following prompts using classifier free guidance with gu
321
359
  `Training Epochs: 1000`
322
360
  `Steps per epoch: 511`
323
361
 
324
- `Training Noise Schedule: CosineNoiseSchedule`
325
- `Inference Noise Schedule: CosineNoiseSchedule`
362
+ `Training Noise Schedule: CosineNoiseScheduler`
363
+ `Inference Noise Schedule: CosineNoiseScheduler`
326
364
 
327
365
  `Model: UNet(emb_features=256,
328
366
  feature_depths=[64, 128, 256, 512],
@@ -74,7 +74,7 @@ Also, few of the text may be generated with help of github copilot, so please ex
74
74
  ### Schedulers
75
75
  Implemented in `flaxdiff.schedulers`:
76
76
  - **LinearNoiseSchedule** (`flaxdiff.schedulers.LinearNoiseSchedule`): A beta-parameterized discrete scheduler.
77
- - **CosineNoiseSchedule** (`flaxdiff.schedulers.CosineNoiseSchedule`): A beta-parameterized discrete scheduler.
77
+ - **CosineNoiseScheduler** (`flaxdiff.schedulers.CosineNoiseScheduler`): A beta-parameterized discrete scheduler.
78
78
  - **ExpNoiseSchedule** (`flaxdiff.schedulers.ExpNoiseSchedule`): A beta-parameterized discrete scheduler.
79
79
  - **CosineContinuousNoiseScheduler** (`flaxdiff.schedulers.CosineContinuousNoiseScheduler`): A continuous scheduler.
80
80
  - **CosineGeneralNoiseScheduler** (`flaxdiff.schedulers.CosineGeneralNoiseScheduler`): A continuous sigma parameterized cosine scheduler.
@@ -125,43 +125,81 @@ sticking to the versions mentioned in the requirements.txt
125
125
  Here is a simplified example to get you started with training a diffusion model using FlaxDiff:
126
126
 
127
127
  ```python
128
- from flaxdiff.schedulers import EDMNoiseScheduler
128
+ from flaxdiff.schedulers import EDMNoiseScheduler, KarrasVENoiseScheduler
129
129
  from flaxdiff.predictors import KarrasPredictionTransform
130
- from flaxdiff.models.simple_unet import SimpleUNet as UNet
130
+ from flaxdiff.models.simple_unet import Unet
131
131
  from flaxdiff.trainer import DiffusionTrainer
132
+ from flaxdiff.data.datasets import get_dataset_grain
133
+ from flaxdiff.utils import defaultTextEncodeModel
134
+ from flaxdiff.samplers.euler import EulerAncestralSampler
132
135
  import jax
136
+ import jax.numpy as jnp
133
137
  import optax
134
138
  from datetime import datetime
135
139
 
136
140
  BATCH_SIZE = 16
137
- IMAGE_SIZE = 64
141
+ IMAGE_SIZE = 128
138
142
 
139
143
  # Define noise scheduler
140
144
  edm_schedule = EDMNoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
141
-
145
+ karas_ve_schedule = KarrasVENoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
142
146
  # Define model
143
- unet = UNet(emb_features=256,
144
- feature_depths=[64, 128, 256, 512],
145
- attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}],
147
+ unet = Unet(emb_features=256,
148
+ feature_depths=[64, 64, 128, 256, 512],
149
+ attention_configs=[
150
+ None,
151
+ {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
152
+ {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
153
+ {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
154
+ {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":False, "use_self_and_cross":False}
155
+ ],
146
156
  num_res_blocks=2,
147
- num_middle_res_blocks=1)
148
-
157
+ num_middle_res_blocks=1
158
+ )
149
159
  # Load dataset
150
- data, datalen = get_dataset("oxford_flowers102", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE)
160
+ data = get_dataset_grain("oxford_flowers102", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE)
161
+ datalen = data['train_len']
151
162
  batches = datalen // BATCH_SIZE
152
163
 
164
+ input_shapes = {
165
+ "x": (IMAGE_SIZE, IMAGE_SIZE, 3),
166
+ "temb": (),
167
+ "textcontext": (77, 768)
168
+ }
169
+ text_encoder = defaultTextEncodeModel()
170
+
171
+ # Construct a validation set by the prompts
172
+ val_prompts = ['water tulip', ' a water lily', ' a water lily', ' a photo of a rose', ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold', ' a photo of a marigold']
173
+
174
+ def get_val_dataset(batch_size=8):
175
+ for i in range(0, len(val_prompts), batch_size):
176
+ prompts = val_prompts[i:i + batch_size]
177
+ tokens = text_encoder.tokenize(prompts)
178
+ yield tokens
179
+
180
+ data['test'] = get_val_dataset
181
+ data['test_len'] = len(val_prompts)
182
+
153
183
  # Define optimizer
154
184
  solver = optax.adam(2e-4)
155
185
 
156
186
  # Create trainer
157
- trainer = DiffusionTrainer(unet, optimizer=solver,
158
- noise_schedule=edm_schedule,
159
- rngs=jax.random.PRNGKey(4),
160
- name="Diffusion_SDE_VE_" + datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
161
- model_output_transform=KarrasPredictionTransform(sigma_data=edm_schedule.sigma_data))
187
+ trainer = DiffusionTrainer(
188
+ unet, optimizer=solver,
189
+ input_shapes=input_shapes,
190
+ noise_schedule=edm_schedule,
191
+ rngs=jax.random.PRNGKey(4),
192
+ name="Diffusion_SDE_VE_" + datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
193
+ model_output_transform=KarrasPredictionTransform(sigma_data=edm_schedule.sigma_data),
194
+ encoder=text_encoder,
195
+ distributed_training=True,
196
+ wandb_config = {
197
+ "project": 'mlops-msml605-project',
198
+ "name": f"prototype-{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}",
199
+ })
162
200
 
163
201
  # Train the model
164
- final_state = trainer.fit(data, batches, epochs=2000)
202
+ final_state = trainer.fit(data, batches, epochs=2000, sampler_class=EulerAncestralSampler, sampling_noise_schedule=karas_ve_schedule)
165
203
  ```
166
204
 
167
205
  ### Inference Example
@@ -279,8 +317,8 @@ Images generated by the following prompts using classifier free guidance with gu
279
317
  `Training Epochs: 1000`
280
318
  `Steps per epoch: 511`
281
319
 
282
- `Training Noise Schedule: CosineNoiseSchedule`
283
- `Inference Noise Schedule: CosineNoiseSchedule`
320
+ `Training Noise Schedule: CosineNoiseScheduler`
321
+ `Inference Noise Schedule: CosineNoiseScheduler`
284
322
 
285
323
  `Model: UNet(emb_features=256,
286
324
  feature_depths=[64, 128, 256, 512],
@@ -299,8 +337,8 @@ Images generated by the following prompts using classifier free guidance with gu
299
337
  `Training Epochs: 1000`
300
338
  `Steps per epoch: 511`
301
339
 
302
- `Training Noise Schedule: CosineNoiseSchedule`
303
- `Inference Noise Schedule: CosineNoiseSchedule`
340
+ `Training Noise Schedule: CosineNoiseScheduler`
341
+ `Inference Noise Schedule: CosineNoiseScheduler`
304
342
 
305
343
  `Model: UNet(emb_features=256,
306
344
  feature_depths=[64, 128, 256, 512],
@@ -1,6 +1,6 @@
1
1
  from .discrete import DiscreteNoiseScheduler
2
2
  from .common import NoiseScheduler, GeneralizedNoiseScheduler
3
- from .cosine import CosineNoiseSchedule, ContinuousNoiseScheduler, CosineGeneralNoiseScheduler
3
+ from .cosine import CosineNoiseScheduler, ContinuousNoiseScheduler, CosineGeneralNoiseScheduler
4
4
  from .linear import LinearNoiseSchedule
5
5
  from .sqrt import SqrtContinuousNoiseScheduler
6
6
  from .karras import KarrasVENoiseScheduler, SimpleExpNoiseScheduler, EDMNoiseScheduler
@@ -12,7 +12,7 @@ def cosine_beta_schedule(timesteps, start_angle=0.008, end_angle=0.999):
12
12
  betas = 1 - (alphas_bar[1:] / alphas_bar[:-1])
13
13
  return np.clip(betas, 0, end_angle)
14
14
 
15
- class CosineNoiseSchedule(DiscreteNoiseScheduler):
15
+ class CosineNoiseScheduler(DiscreteNoiseScheduler):
16
16
  def __init__(self, timesteps, beta_start=0.008, beta_end=0.999, *args, **kwargs):
17
17
  super().__init__(timesteps, beta_start, beta_end, schedule_fn=cosine_beta_schedule, *args, **kwargs)
18
18
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.1.36.4
3
+ Version: 0.1.36.5
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
@@ -96,7 +96,7 @@ Also, few of the text may be generated with help of github copilot, so please ex
96
96
  ### Schedulers
97
97
  Implemented in `flaxdiff.schedulers`:
98
98
  - **LinearNoiseSchedule** (`flaxdiff.schedulers.LinearNoiseSchedule`): A beta-parameterized discrete scheduler.
99
- - **CosineNoiseSchedule** (`flaxdiff.schedulers.CosineNoiseSchedule`): A beta-parameterized discrete scheduler.
99
+ - **CosineNoiseScheduler** (`flaxdiff.schedulers.CosineNoiseScheduler`): A beta-parameterized discrete scheduler.
100
100
  - **ExpNoiseSchedule** (`flaxdiff.schedulers.ExpNoiseSchedule`): A beta-parameterized discrete scheduler.
101
101
  - **CosineContinuousNoiseScheduler** (`flaxdiff.schedulers.CosineContinuousNoiseScheduler`): A continuous scheduler.
102
102
  - **CosineGeneralNoiseScheduler** (`flaxdiff.schedulers.CosineGeneralNoiseScheduler`): A continuous sigma parameterized cosine scheduler.
@@ -147,43 +147,81 @@ sticking to the versions mentioned in the requirements.txt
147
147
  Here is a simplified example to get you started with training a diffusion model using FlaxDiff:
148
148
 
149
149
  ```python
150
- from flaxdiff.schedulers import EDMNoiseScheduler
150
+ from flaxdiff.schedulers import EDMNoiseScheduler, KarrasVENoiseScheduler
151
151
  from flaxdiff.predictors import KarrasPredictionTransform
152
- from flaxdiff.models.simple_unet import SimpleUNet as UNet
152
+ from flaxdiff.models.simple_unet import Unet
153
153
  from flaxdiff.trainer import DiffusionTrainer
154
+ from flaxdiff.data.datasets import get_dataset_grain
155
+ from flaxdiff.utils import defaultTextEncodeModel
156
+ from flaxdiff.samplers.euler import EulerAncestralSampler
154
157
  import jax
158
+ import jax.numpy as jnp
155
159
  import optax
156
160
  from datetime import datetime
157
161
 
158
162
  BATCH_SIZE = 16
159
- IMAGE_SIZE = 64
163
+ IMAGE_SIZE = 128
160
164
 
161
165
  # Define noise scheduler
162
166
  edm_schedule = EDMNoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
163
-
167
+ karas_ve_schedule = KarrasVENoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
164
168
  # Define model
165
- unet = UNet(emb_features=256,
166
- feature_depths=[64, 128, 256, 512],
167
- attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}],
169
+ unet = Unet(emb_features=256,
170
+ feature_depths=[64, 64, 128, 256, 512],
171
+ attention_configs=[
172
+ None,
173
+ {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
174
+ {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
175
+ {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
176
+ {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":False, "use_self_and_cross":False}
177
+ ],
168
178
  num_res_blocks=2,
169
- num_middle_res_blocks=1)
170
-
179
+ num_middle_res_blocks=1
180
+ )
171
181
  # Load dataset
172
- data, datalen = get_dataset("oxford_flowers102", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE)
182
+ data = get_dataset_grain("oxford_flowers102", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE)
183
+ datalen = data['train_len']
173
184
  batches = datalen // BATCH_SIZE
174
185
 
186
+ input_shapes = {
187
+ "x": (IMAGE_SIZE, IMAGE_SIZE, 3),
188
+ "temb": (),
189
+ "textcontext": (77, 768)
190
+ }
191
+ text_encoder = defaultTextEncodeModel()
192
+
193
+ # Construct a validation set by the prompts
194
+ val_prompts = ['water tulip', ' a water lily', ' a water lily', ' a photo of a rose', ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold', ' a photo of a marigold']
195
+
196
+ def get_val_dataset(batch_size=8):
197
+ for i in range(0, len(val_prompts), batch_size):
198
+ prompts = val_prompts[i:i + batch_size]
199
+ tokens = text_encoder.tokenize(prompts)
200
+ yield tokens
201
+
202
+ data['test'] = get_val_dataset
203
+ data['test_len'] = len(val_prompts)
204
+
175
205
  # Define optimizer
176
206
  solver = optax.adam(2e-4)
177
207
 
178
208
  # Create trainer
179
- trainer = DiffusionTrainer(unet, optimizer=solver,
180
- noise_schedule=edm_schedule,
181
- rngs=jax.random.PRNGKey(4),
182
- name="Diffusion_SDE_VE_" + datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
183
- model_output_transform=KarrasPredictionTransform(sigma_data=edm_schedule.sigma_data))
209
+ trainer = DiffusionTrainer(
210
+ unet, optimizer=solver,
211
+ input_shapes=input_shapes,
212
+ noise_schedule=edm_schedule,
213
+ rngs=jax.random.PRNGKey(4),
214
+ name="Diffusion_SDE_VE_" + datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
215
+ model_output_transform=KarrasPredictionTransform(sigma_data=edm_schedule.sigma_data),
216
+ encoder=text_encoder,
217
+ distributed_training=True,
218
+ wandb_config = {
219
+ "project": 'mlops-msml605-project',
220
+ "name": f"prototype-{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}",
221
+ })
184
222
 
185
223
  # Train the model
186
- final_state = trainer.fit(data, batches, epochs=2000)
224
+ final_state = trainer.fit(data, batches, epochs=2000, sampler_class=EulerAncestralSampler, sampling_noise_schedule=karas_ve_schedule)
187
225
  ```
188
226
 
189
227
  ### Inference Example
@@ -301,8 +339,8 @@ Images generated by the following prompts using classifier free guidance with gu
301
339
  `Training Epochs: 1000`
302
340
  `Steps per epoch: 511`
303
341
 
304
- `Training Noise Schedule: CosineNoiseSchedule`
305
- `Inference Noise Schedule: CosineNoiseSchedule`
342
+ `Training Noise Schedule: CosineNoiseScheduler`
343
+ `Inference Noise Schedule: CosineNoiseScheduler`
306
344
 
307
345
  `Model: UNet(emb_features=256,
308
346
  feature_depths=[64, 128, 256, 512],
@@ -321,8 +359,8 @@ Images generated by the following prompts using classifier free guidance with gu
321
359
  `Training Epochs: 1000`
322
360
  `Steps per epoch: 511`
323
361
 
324
- `Training Noise Schedule: CosineNoiseSchedule`
325
- `Inference Noise Schedule: CosineNoiseSchedule`
362
+ `Training Noise Schedule: CosineNoiseScheduler`
363
+ `Inference Noise Schedule: CosineNoiseScheduler`
326
364
 
327
365
  `Model: UNet(emb_features=256,
328
366
  feature_depths=[64, 128, 256, 512],
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "flaxdiff"
7
- version = "0.1.36.4"
7
+ version = "0.1.36.5"
8
8
  description = "A versatile and easy to understand Diffusion library"
9
9
  readme = "README.md"
10
10
  authors = [
File without changes