flaxdiff 0.1.36.4__py3-none-any.whl → 0.1.36.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  from .discrete import DiscreteNoiseScheduler
2
2
  from .common import NoiseScheduler, GeneralizedNoiseScheduler
3
- from .cosine import CosineNoiseSchedule, ContinuousNoiseScheduler, CosineGeneralNoiseScheduler
3
+ from .cosine import CosineNoiseScheduler, ContinuousNoiseScheduler, CosineGeneralNoiseScheduler
4
4
  from .linear import LinearNoiseSchedule
5
5
  from .sqrt import SqrtContinuousNoiseScheduler
6
6
  from .karras import KarrasVENoiseScheduler, SimpleExpNoiseScheduler, EDMNoiseScheduler
@@ -12,7 +12,7 @@ def cosine_beta_schedule(timesteps, start_angle=0.008, end_angle=0.999):
12
12
  betas = 1 - (alphas_bar[1:] / alphas_bar[:-1])
13
13
  return np.clip(betas, 0, end_angle)
14
14
 
15
- class CosineNoiseSchedule(DiscreteNoiseScheduler):
15
+ class CosineNoiseScheduler(DiscreteNoiseScheduler):
16
16
  def __init__(self, timesteps, beta_start=0.008, beta_end=0.999, *args, **kwargs):
17
17
  super().__init__(timesteps, beta_start, beta_end, schedule_fn=cosine_beta_schedule, *args, **kwargs)
18
18
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.1.36.4
3
+ Version: 0.1.36.5
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
@@ -96,7 +96,7 @@ Also, few of the text may be generated with help of github copilot, so please ex
96
96
  ### Schedulers
97
97
  Implemented in `flaxdiff.schedulers`:
98
98
  - **LinearNoiseSchedule** (`flaxdiff.schedulers.LinearNoiseSchedule`): A beta-parameterized discrete scheduler.
99
- - **CosineNoiseSchedule** (`flaxdiff.schedulers.CosineNoiseSchedule`): A beta-parameterized discrete scheduler.
99
+ - **CosineNoiseScheduler** (`flaxdiff.schedulers.CosineNoiseScheduler`): A beta-parameterized discrete scheduler.
100
100
  - **ExpNoiseSchedule** (`flaxdiff.schedulers.ExpNoiseSchedule`): A beta-parameterized discrete scheduler.
101
101
  - **CosineContinuousNoiseScheduler** (`flaxdiff.schedulers.CosineContinuousNoiseScheduler`): A continuous scheduler.
102
102
  - **CosineGeneralNoiseScheduler** (`flaxdiff.schedulers.CosineGeneralNoiseScheduler`): A continuous sigma parameterized cosine scheduler.
@@ -147,43 +147,81 @@ sticking to the versions mentioned in the requirements.txt
147
147
  Here is a simplified example to get you started with training a diffusion model using FlaxDiff:
148
148
 
149
149
  ```python
150
- from flaxdiff.schedulers import EDMNoiseScheduler
150
+ from flaxdiff.schedulers import EDMNoiseScheduler, KarrasVENoiseScheduler
151
151
  from flaxdiff.predictors import KarrasPredictionTransform
152
- from flaxdiff.models.simple_unet import SimpleUNet as UNet
152
+ from flaxdiff.models.simple_unet import Unet
153
153
  from flaxdiff.trainer import DiffusionTrainer
154
+ from flaxdiff.data.datasets import get_dataset_grain
155
+ from flaxdiff.utils import defaultTextEncodeModel
156
+ from flaxdiff.samplers.euler import EulerAncestralSampler
154
157
  import jax
158
+ import jax.numpy as jnp
155
159
  import optax
156
160
  from datetime import datetime
157
161
 
158
162
  BATCH_SIZE = 16
159
- IMAGE_SIZE = 64
163
+ IMAGE_SIZE = 128
160
164
 
161
165
  # Define noise scheduler
162
166
  edm_schedule = EDMNoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
163
-
167
+ karas_ve_schedule = KarrasVENoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
164
168
  # Define model
165
- unet = UNet(emb_features=256,
166
- feature_depths=[64, 128, 256, 512],
167
- attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}],
169
+ unet = Unet(emb_features=256,
170
+ feature_depths=[64, 64, 128, 256, 512],
171
+ attention_configs=[
172
+ None,
173
+ {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
174
+ {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
175
+ {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
176
+ {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":False, "use_self_and_cross":False}
177
+ ],
168
178
  num_res_blocks=2,
169
- num_middle_res_blocks=1)
170
-
179
+ num_middle_res_blocks=1
180
+ )
171
181
  # Load dataset
172
- data, datalen = get_dataset("oxford_flowers102", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE)
182
+ data = get_dataset_grain("oxford_flowers102", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE)
183
+ datalen = data['train_len']
173
184
  batches = datalen // BATCH_SIZE
174
185
 
186
+ input_shapes = {
187
+ "x": (IMAGE_SIZE, IMAGE_SIZE, 3),
188
+ "temb": (),
189
+ "textcontext": (77, 768)
190
+ }
191
+ text_encoder = defaultTextEncodeModel()
192
+
193
+ # Construct a validation set from the prompts
194
+ val_prompts = ['water tulip', ' a water lily', ' a water lily', ' a photo of a rose', ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold', ' a photo of a marigold']
195
+
196
+ def get_val_dataset(batch_size=8):
197
+ for i in range(0, len(val_prompts), batch_size):
198
+ prompts = val_prompts[i:i + batch_size]
199
+ tokens = text_encoder.tokenize(prompts)
200
+ yield tokens
201
+
202
+ data['test'] = get_val_dataset
203
+ data['test_len'] = len(val_prompts)
204
+
175
205
  # Define optimizer
176
206
  solver = optax.adam(2e-4)
177
207
 
178
208
  # Create trainer
179
- trainer = DiffusionTrainer(unet, optimizer=solver,
180
- noise_schedule=edm_schedule,
181
- rngs=jax.random.PRNGKey(4),
182
- name="Diffusion_SDE_VE_" + datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
183
- model_output_transform=KarrasPredictionTransform(sigma_data=edm_schedule.sigma_data))
209
+ trainer = DiffusionTrainer(
210
+ unet, optimizer=solver,
211
+ input_shapes=input_shapes,
212
+ noise_schedule=edm_schedule,
213
+ rngs=jax.random.PRNGKey(4),
214
+ name="Diffusion_SDE_VE_" + datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
215
+ model_output_transform=KarrasPredictionTransform(sigma_data=edm_schedule.sigma_data),
216
+ encoder=text_encoder,
217
+ distributed_training=True,
218
+ wandb_config = {
219
+ "project": 'mlops-msml605-project',
220
+ "name": f"prototype-{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}",
221
+ })
184
222
 
185
223
  # Train the model
186
- final_state = trainer.fit(data, batches, epochs=2000)
224
+ final_state = trainer.fit(data, batches, epochs=2000, sampler_class=EulerAncestralSampler, sampling_noise_schedule=karas_ve_schedule)
187
225
  ```
188
226
 
189
227
  ### Inference Example
@@ -301,8 +339,8 @@ Images generated by the following prompts using classifier free guidance with gu
301
339
  `Training Epochs: 1000`
302
340
  `Steps per epoch: 511`
303
341
 
304
- `Training Noise Schedule: CosineNoiseSchedule`
305
- `Inference Noise Schedule: CosineNoiseSchedule`
342
+ `Training Noise Schedule: CosineNoiseScheduler`
343
+ `Inference Noise Schedule: CosineNoiseScheduler`
306
344
 
307
345
  `Model: UNet(emb_features=256,
308
346
  feature_depths=[64, 128, 256, 512],
@@ -321,8 +359,8 @@ Images generated by the following prompts using classifier free guidance with gu
321
359
  `Training Epochs: 1000`
322
360
  `Steps per epoch: 511`
323
361
 
324
- `Training Noise Schedule: CosineNoiseSchedule`
325
- `Inference Noise Schedule: CosineNoiseSchedule`
362
+ `Training Noise Schedule: CosineNoiseScheduler`
363
+ `Inference Noise Schedule: CosineNoiseScheduler`
326
364
 
327
365
  `Model: UNet(emb_features=256,
328
366
  feature_depths=[64, 128, 256, 512],
@@ -27,10 +27,10 @@ flaxdiff/samplers/euler.py,sha256=Htb-IJeu7jSgY6mvgYr9yl9pUnos49vijlVk5IQsRps,27
27
27
  flaxdiff/samplers/heun_sampler.py,sha256=UyI-hSlyWvt-7VEUJj27zjgyzKkGVl8fDUHV-YpSOCc,1421
28
28
  flaxdiff/samplers/multistep_dpm.py,sha256=3Wu3MrMLYaBb1ObraTbWrJmtEtU0adl1dDbz5fPJ4Gs,2735
29
29
  flaxdiff/samplers/rk4_sampler.py,sha256=1j1pES_Q2QiaURvEWeedbbT1LHmkc3jsu0GgH83qBL0,1926
30
- flaxdiff/schedulers/__init__.py,sha256=3id390WEfdf-MN-oLSPAhlRFIXrFWr6ioAHPAwURJyE,375
30
+ flaxdiff/schedulers/__init__.py,sha256=EcABJ5UqsfeFXD9ypbgjVSYb6IKm7bcmsUbqEwVpHUc,376
31
31
  flaxdiff/schedulers/common.py,sha256=b-W4iI-aqScpVE8VZbBpiYvAVI6rqDkUP-C_hEVBwCI,4151
32
32
  flaxdiff/schedulers/continuous.py,sha256=5c_niOA20fxJ5oJDi09FfayIRogBGwtfG0XThW2IUZk,334
33
- flaxdiff/schedulers/cosine.py,sha256=9ban0dFHLMm35wQvaBT4nCQwPGmzNsXwQ1xI0oppmJI,2005
33
+ flaxdiff/schedulers/cosine.py,sha256=EtU3SjJaP9R9ULHNiYrX9jBLSsAGKPGteHiwOzWNzYo,2006
34
34
  flaxdiff/schedulers/discrete.py,sha256=O54wH2HVu3olJA71NxgAXFW9cr6B6Gl-DR_uZeytpds,3319
35
35
  flaxdiff/schedulers/exp.py,sha256=cPTnUJpYdzJRRZqMLYQz0rRUCpEmaP2tXhRumLx94jA,605
36
36
  flaxdiff/schedulers/karras.py,sha256=4GN120kGwdxxU-h2mVdhBVy9IORkUMm_vvz3XjthBcI,3355
@@ -41,7 +41,7 @@ flaxdiff/trainer/autoencoder_trainer.py,sha256=hxihkRL9WCIQVGOP-pc1jjjIUaRXDLcNo
41
41
  flaxdiff/trainer/diffusion_trainer.py,sha256=zde_nRzsC2GD5KNCn5Qjw9ldHi7L_-teJhcUNUDCdcQ,12815
42
42
  flaxdiff/trainer/simple_trainer.py,sha256=lmRo8N0bMupIyS3ejPvPtxoskY_3GLC8iyJE6u4TIWc,21990
43
43
  flaxdiff/trainer/video_diffusion_trainer.py,sha256=gMkKpnKNTo8QhTx5ptEEkc7W5-7rzXIr9queU53hXyQ,2197
44
- flaxdiff-0.1.36.4.dist-info/METADATA,sha256=MTgRu4VgbQaGqbGv_S3wXd_dzeNmHXnixRdvs93dWj0,22310
45
- flaxdiff-0.1.36.4.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
46
- flaxdiff-0.1.36.4.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
47
- flaxdiff-0.1.36.4.dist-info/RECORD,,
44
+ flaxdiff-0.1.36.5.dist-info/METADATA,sha256=Bk4FoPnJ0DlpficfpLQ9t0SaE13xe26xs2gEb2BYdfI,23985
45
+ flaxdiff-0.1.36.5.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
46
+ flaxdiff-0.1.36.5.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
47
+ flaxdiff-0.1.36.5.dist-info/RECORD,,