flaxdiff 0.1.36.3__py3-none-any.whl → 0.1.36.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,6 +4,8 @@ import grain.python as pygrain
 from flaxdiff.utils import AutoTextTokenizer
 from typing import Dict
 import random
+import augmax
+import jax
 
 # -----------------------------------------------------------------------------------------------#
 # Oxford flowers and other TFDS datasources -----------------------------------------------------#
@@ -47,6 +49,15 @@ def tfds_augmenters(image_scale, method):
         interpolation = cv2.INTER_CUBIC
     else:
         interpolation = cv2.INTER_AREA
+
+    augments = augmax.Chain(
+        augmax.HorizontalFlip(0.5),
+        augmax.RandomContrast((-0.05, 0.05), 1.),
+        augmax.RandomBrightness((-0.2, 0.2), 1.)
+    )
+
+    augments = jax.jit(augments, backend="cpu")
+
     class augmenters(pygrain.MapTransform):
         def __init__(self, *args, **kwargs):
             super().__init__(*args, **kwargs)
@@ -56,6 +67,7 @@ def tfds_augmenters(image_scale, method):
             image = element['image']
             image = cv2.resize(image, (image_scale, image_scale),
                                interpolation=interpolation)
+            # image = augments(image)
             # image = (image - 127.5) / 127.5
             caption = labelizer(element)
             results = self.tokenize(caption)
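The hunks above add an `augmax` augmentation chain, jitted on the CPU backend, to the TFDS data source (`flaxdiff/data/sources/tfds.py` per the RECORD change), but the call site is left commented out. Below is a minimal sketch of how such a chain is typically invoked; the PRNG-key handling and the placeholder image are assumptions for illustration, not part of this release.

```python
import augmax
import jax
import jax.numpy as jnp

# Same chain as in the diff: random flip plus mild contrast/brightness jitter.
augments = augmax.Chain(
    augmax.HorizontalFlip(0.5),
    augmax.RandomContrast((-0.05, 0.05), 1.),
    augmax.RandomBrightness((-0.2, 0.2), 1.),
)
# Jit on the host CPU so augmentation does not compete with the training devices.
augments = jax.jit(augments, backend="cpu")

# Illustrative usage: augmax chains take a PRNG key along with the image.
key = jax.random.PRNGKey(0)
image = jnp.zeros((128, 128, 3), dtype=jnp.float32)  # stand-in for a decoded, resized RGB image
augmented = augments(key, image)
```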
@@ -1,6 +1,6 @@
 from .discrete import DiscreteNoiseScheduler
 from .common import NoiseScheduler, GeneralizedNoiseScheduler
-from .cosine import CosineNoiseSchedule, ContinuousNoiseScheduler, CosineGeneralNoiseScheduler
+from .cosine import CosineNoiseScheduler, ContinuousNoiseScheduler, CosineGeneralNoiseScheduler
 from .linear import LinearNoiseSchedule
 from .sqrt import SqrtContinuousNoiseScheduler
 from .karras import KarrasVENoiseScheduler, SimpleExpNoiseScheduler, EDMNoiseScheduler
@@ -12,7 +12,7 @@ def cosine_beta_schedule(timesteps, start_angle=0.008, end_angle=0.999):
     betas = 1 - (alphas_bar[1:] / alphas_bar[:-1])
     return np.clip(betas, 0, end_angle)
 
-class CosineNoiseSchedule(DiscreteNoiseScheduler):
+class CosineNoiseScheduler(DiscreteNoiseScheduler):
     def __init__(self, timesteps, beta_start=0.008, beta_end=0.999, *args, **kwargs):
         super().__init__(timesteps, beta_start, beta_end, schedule_fn=cosine_beta_schedule, *args, **kwargs)
 
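The two hunks above rename the discrete cosine schedule from `CosineNoiseSchedule` to `CosineNoiseScheduler`, both in `schedulers/__init__.py` and in the class definition itself. A minimal migration sketch, assuming existing code imported the old name; the 1000-step count is illustrative only:

```python
# Before 0.1.36.5 (old name):
# from flaxdiff.schedulers import CosineNoiseSchedule
# schedule = CosineNoiseSchedule(1000)

# From 0.1.36.5 onward (new name, same constructor signature):
from flaxdiff.schedulers import CosineNoiseScheduler

schedule = CosineNoiseScheduler(1000, beta_start=0.008, beta_end=0.999)
```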
@@ -14,6 +14,7 @@ from typing import Dict, Callable, Sequence, Any, Union, Tuple, Type
 from ..schedulers import NoiseScheduler
 from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
 from ..samplers.common import DiffusionSampler
+from ..samplers.ddim import DDIMSampler
 
 from flaxdiff.utils import RandomMarkovState
 
@@ -179,9 +180,6 @@ class DiffusionTrainer(SimpleTrainer):
             nloss = loss_fn(preds, expected_output)
             # Ignore the loss contribution of images with zero standard deviation
             nloss *= noise_schedule.get_weights(noise_level)
-            # nloss = jnp.mean(nloss, axis=(1,2,3))
-            # nloss = jnp.where(is_non_zero, nloss, 0)
-            # nloss = jnp.mean(nloss, where=nloss != 0)
             nloss = jnp.mean(nloss)
             loss = nloss
             return loss
@@ -224,11 +222,11 @@ class DiffusionTrainer(SimpleTrainer):
         if distributed_training:
             train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
                                    out_specs=(P(), P(), P()))
-            train_step = jax.jit(train_step)
+        train_step = jax.jit(train_step)
 
         return train_step
 
-    def _define_vaidation_step(self, sampler_class: Type[DiffusionSampler]):
+    def _define_vaidation_step(self, sampler_class: Type[DiffusionSampler]=DDIMSampler, sampling_noise_schedule: NoiseScheduler=None):
         model = self.model
         encoder = self.encoder
         autoencoder = self.autoencoder
@@ -241,7 +239,7 @@ class DiffusionTrainer(SimpleTrainer):
         sampler = sampler_class(
             model=model,
             params=state.ema_params,
-            noise_schedule=self.noise_schedule,
+            noise_schedule=self.noise_schedule if sampling_noise_schedule is None else sampling_noise_schedule,
             model_output_transform=self.model_output_transform,
             image_size=self.input_shapes['x'][0],
             null_labels_seq=null_labels_full,
@@ -311,10 +309,11 @@ class DiffusionTrainer(SimpleTrainer):
                 print("Error logging images to wandb", e)
                 traceback.print_exc()
 
-    def fit(self, data, training_steps_per_epoch, epochs, val_steps_per_epoch=8, sampler_class=None):
+    def fit(self, data, training_steps_per_epoch, epochs, val_steps_per_epoch=8, sampler_class: Type[DiffusionSampler]=DDIMSampler, sampling_noise_schedule: NoiseScheduler=None):
         local_batch_size = data['local_batch_size']
         validation_step_args = {
             "sampler_class": sampler_class,
+            "sampling_noise_schedule": sampling_noise_schedule,
         }
         super().fit(
             data,
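Taken together, the trainer hunks above make `DDIMSampler` the default validation sampler and let callers pass a separate `sampling_noise_schedule` through `fit`. A hedged usage sketch, assuming `trainer`, `data`, and `batches` are already set up as in the README training example further down:

```python
from flaxdiff.samplers.ddim import DDIMSampler
from flaxdiff.schedulers import KarrasVENoiseScheduler

# Validation images are sampled with DDIM (the new default) while training still
# uses the trainer's own noise_schedule; the Karras VE schedule here applies only
# at sampling time, mirroring the updated fit() signature.
final_state = trainer.fit(
    data, batches, epochs=2000,
    sampler_class=DDIMSampler,
    sampling_noise_schedule=KarrasVENoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5),
)
```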
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: flaxdiff
-Version: 0.1.36.3
+Version: 0.1.36.5
 Summary: A versatile and easy to understand Diffusion library
 Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
 License-Expression: MIT
@@ -96,7 +96,7 @@ Also, few of the text may be generated with help of github copilot, so please ex
 ### Schedulers
 Implemented in `flaxdiff.schedulers`:
 - **LinearNoiseSchedule** (`flaxdiff.schedulers.LinearNoiseSchedule`): A beta-parameterized discrete scheduler.
-- **CosineNoiseSchedule** (`flaxdiff.schedulers.CosineNoiseSchedule`): A beta-parameterized discrete scheduler.
+- **CosineNoiseScheduler** (`flaxdiff.schedulers.CosineNoiseScheduler`): A beta-parameterized discrete scheduler.
 - **ExpNoiseSchedule** (`flaxdiff.schedulers.ExpNoiseSchedule`): A beta-parameterized discrete scheduler.
 - **CosineContinuousNoiseScheduler** (`flaxdiff.schedulers.CosineContinuousNoiseScheduler`): A continuous scheduler.
 - **CosineGeneralNoiseScheduler** (`flaxdiff.schedulers.CosineGeneralNoiseScheduler`): A continuous sigma parameterized cosine scheduler.
@@ -147,43 +147,81 @@ sticking to the versions mentioned in the requirements.txt
 Here is a simplified example to get you started with training a diffusion model using FlaxDiff:
 
 ```python
-from flaxdiff.schedulers import EDMNoiseScheduler
+from flaxdiff.schedulers import EDMNoiseScheduler, KarrasVENoiseScheduler
 from flaxdiff.predictors import KarrasPredictionTransform
-from flaxdiff.models.simple_unet import SimpleUNet as UNet
+from flaxdiff.models.simple_unet import Unet
 from flaxdiff.trainer import DiffusionTrainer
+from flaxdiff.data.datasets import get_dataset_grain
+from flaxdiff.utils import defaultTextEncodeModel
+from flaxdiff.samplers.euler import EulerAncestralSampler
 import jax
+import jax.numpy as jnp
 import optax
 from datetime import datetime
 
 BATCH_SIZE = 16
-IMAGE_SIZE = 64
+IMAGE_SIZE = 128
 
 # Define noise scheduler
 edm_schedule = EDMNoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
-
+karas_ve_schedule = KarrasVENoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
 # Define model
-unet = UNet(emb_features=256,
-            feature_depths=[64, 128, 256, 512],
-            attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}],
+unet = Unet(emb_features=256,
+            feature_depths=[64, 64, 128, 256, 512],
+            attention_configs=[
+                None,
+                {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
+                {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
+                {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
+                {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":False, "use_self_and_cross":False}
+            ],
             num_res_blocks=2,
-            num_middle_res_blocks=1)
-
+            num_middle_res_blocks=1
+)
 # Load dataset
-data, datalen = get_dataset("oxford_flowers102", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE)
+data = get_dataset_grain("oxford_flowers102", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE)
+datalen = data['train_len']
 batches = datalen // BATCH_SIZE
 
+input_shapes = {
+    "x": (IMAGE_SIZE, IMAGE_SIZE, 3),
+    "temb": (),
+    "textcontext": (77, 768)
+}
+text_encoder = defaultTextEncodeModel()
+
+# Construct a validation set by the prompts
+val_prompts = ['water tulip', ' a water lily', ' a water lily', ' a photo of a rose', ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold', ' a photo of a marigold']
+
+def get_val_dataset(batch_size=8):
+    for i in range(0, len(val_prompts), batch_size):
+        prompts = val_prompts[i:i + batch_size]
+        tokens = text_encoder.tokenize(prompts)
+        yield tokens
+
+data['test'] = get_val_dataset
+data['test_len'] = len(val_prompts)
+
 # Define optimizer
 solver = optax.adam(2e-4)
 
 # Create trainer
-trainer = DiffusionTrainer(unet, optimizer=solver,
-                           noise_schedule=edm_schedule,
-                           rngs=jax.random.PRNGKey(4),
-                           name="Diffusion_SDE_VE_" + datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
-                           model_output_transform=KarrasPredictionTransform(sigma_data=edm_schedule.sigma_data))
+trainer = DiffusionTrainer(
+    unet, optimizer=solver,
+    input_shapes=input_shapes,
+    noise_schedule=edm_schedule,
+    rngs=jax.random.PRNGKey(4),
+    name="Diffusion_SDE_VE_" + datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
+    model_output_transform=KarrasPredictionTransform(sigma_data=edm_schedule.sigma_data),
+    encoder=text_encoder,
+    distributed_training=True,
+    wandb_config = {
+        "project": 'mlops-msml605-project',
+        "name": f"prototype-{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}",
+    })
 
 # Train the model
-final_state = trainer.fit(data, batches, epochs=2000)
+final_state = trainer.fit(data, batches, epochs=2000, sampler_class=EulerAncestralSampler, sampling_noise_schedule=karas_ve_schedule)
 ```
 
 ### Inference Example
@@ -301,8 +339,8 @@ Images generated by the following prompts using classifier free guidance with gu
 `Training Epochs: 1000`
 `Steps per epoch: 511`
 
-`Training Noise Schedule: CosineNoiseSchedule`
-`Inference Noise Schedule: CosineNoiseSchedule`
+`Training Noise Schedule: CosineNoiseScheduler`
+`Inference Noise Schedule: CosineNoiseScheduler`
 
 `Model: UNet(emb_features=256,
             feature_depths=[64, 128, 256, 512],
@@ -321,8 +359,8 @@ Images generated by the following prompts using classifier free guidance with gu
 `Training Epochs: 1000`
 `Steps per epoch: 511`
 
-`Training Noise Schedule: CosineNoiseSchedule`
-`Inference Noise Schedule: CosineNoiseSchedule`
+`Training Noise Schedule: CosineNoiseScheduler`
+`Inference Noise Schedule: CosineNoiseScheduler`
 
 `Model: UNet(emb_features=256,
             feature_depths=[64, 128, 256, 512],
@@ -5,7 +5,7 @@ flaxdiff/data/dataset_map.py,sha256=hcHaoR2IbNQmfyPUhYd6_8xinurxxCqawQijAsDI0Ek,
 flaxdiff/data/datasets.py,sha256=YUMoSvF2yAyikRvRofZVlHwfEOU3zXSSG4KkLnVfpoA,5626
 flaxdiff/data/online_loader.py,sha256=1Fi_QRixxRzbt602nORINcDeHEccvCrBpagrz4PURYg,12499
 flaxdiff/data/sources/gcs.py,sha256=11ZuQhvMyJRLg21DgVdzO5qEuae7zgzTXGNOskF-cbs,3380
-flaxdiff/data/sources/tfds.py,sha256=WA3h9lyR4yotCNEmJON2noIN-2HNcqhf6zigx1XXsMI,2481
+flaxdiff/data/sources/tfds.py,sha256=7n-uobG_UvkD5mU_1ovPd9kb6xJrbEKFFXdVEHDunts,2781
 flaxdiff/metrics/inception.py,sha256=a5kjMCPMT9gB88c_HCKiek-2vsAyoE35K7nDt4h4pVI,31843
 flaxdiff/metrics/utils.py,sha256=YuuOfqvqgIjsceupwNeJ59vQ2TnGeNMIyKdkIqOmoNg,1702
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
@@ -27,10 +27,10 @@ flaxdiff/samplers/euler.py,sha256=Htb-IJeu7jSgY6mvgYr9yl9pUnos49vijlVk5IQsRps,27
 flaxdiff/samplers/heun_sampler.py,sha256=UyI-hSlyWvt-7VEUJj27zjgyzKkGVl8fDUHV-YpSOCc,1421
 flaxdiff/samplers/multistep_dpm.py,sha256=3Wu3MrMLYaBb1ObraTbWrJmtEtU0adl1dDbz5fPJ4Gs,2735
 flaxdiff/samplers/rk4_sampler.py,sha256=1j1pES_Q2QiaURvEWeedbbT1LHmkc3jsu0GgH83qBL0,1926
-flaxdiff/schedulers/__init__.py,sha256=3id390WEfdf-MN-oLSPAhlRFIXrFWr6ioAHPAwURJyE,375
+flaxdiff/schedulers/__init__.py,sha256=EcABJ5UqsfeFXD9ypbgjVSYb6IKm7bcmsUbqEwVpHUc,376
 flaxdiff/schedulers/common.py,sha256=b-W4iI-aqScpVE8VZbBpiYvAVI6rqDkUP-C_hEVBwCI,4151
 flaxdiff/schedulers/continuous.py,sha256=5c_niOA20fxJ5oJDi09FfayIRogBGwtfG0XThW2IUZk,334
-flaxdiff/schedulers/cosine.py,sha256=9ban0dFHLMm35wQvaBT4nCQwPGmzNsXwQ1xI0oppmJI,2005
+flaxdiff/schedulers/cosine.py,sha256=EtU3SjJaP9R9ULHNiYrX9jBLSsAGKPGteHiwOzWNzYo,2006
 flaxdiff/schedulers/discrete.py,sha256=O54wH2HVu3olJA71NxgAXFW9cr6B6Gl-DR_uZeytpds,3319
 flaxdiff/schedulers/exp.py,sha256=cPTnUJpYdzJRRZqMLYQz0rRUCpEmaP2tXhRumLx94jA,605
 flaxdiff/schedulers/karras.py,sha256=4GN120kGwdxxU-h2mVdhBVy9IORkUMm_vvz3XjthBcI,3355
@@ -38,10 +38,10 @@ flaxdiff/schedulers/linear.py,sha256=6003F5ISq1Wc0h6UAzY95MJgsDIKGMhBzbiVALpea0k
 flaxdiff/schedulers/sqrt.py,sha256=1F84ZgQPuoNMhe6yxGTR2G0h7dPOZtm4UDQOakbSsEU,445
 flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,128
 flaxdiff/trainer/autoencoder_trainer.py,sha256=hxihkRL9WCIQVGOP-pc1jjjIUaRXDLcNo3_erTKsuWM,7049
-flaxdiff/trainer/diffusion_trainer.py,sha256=ajOWBgFFwXP_VQScUjcuPoaB4Gk02aF0Ls5LNlA8wqA,12691
+flaxdiff/trainer/diffusion_trainer.py,sha256=zde_nRzsC2GD5KNCn5Qjw9ldHi7L_-teJhcUNUDCdcQ,12815
 flaxdiff/trainer/simple_trainer.py,sha256=lmRo8N0bMupIyS3ejPvPtxoskY_3GLC8iyJE6u4TIWc,21990
 flaxdiff/trainer/video_diffusion_trainer.py,sha256=gMkKpnKNTo8QhTx5ptEEkc7W5-7rzXIr9queU53hXyQ,2197
-flaxdiff-0.1.36.3.dist-info/METADATA,sha256=9XaZMJ6SMFP7OUn2tp9v5FQveMGoxvuiyxdJ8SmMd8w,22310
-flaxdiff-0.1.36.3.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
-flaxdiff-0.1.36.3.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
-flaxdiff-0.1.36.3.dist-info/RECORD,,
+flaxdiff-0.1.36.5.dist-info/METADATA,sha256=Bk4FoPnJ0DlpficfpLQ9t0SaE13xe26xs2gEb2BYdfI,23985
+flaxdiff-0.1.36.5.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+flaxdiff-0.1.36.5.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.36.5.dist-info/RECORD,,