flaxdiff 0.1.36.3__tar.gz → 0.1.36.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/PKG-INFO +1 -1
  2. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/data/sources/tfds.py +12 -0
  3. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/trainer/diffusion_trainer.py +6 -7
  4. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff.egg-info/PKG-INFO +1 -1
  5. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/pyproject.toml +1 -1
  6. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/README.md +0 -0
  7. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/__init__.py +0 -0
  8. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/data/__init__.py +0 -0
  9. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/data/dataset_map.py +0 -0
  10. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/data/datasets.py +0 -0
  11. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/data/online_loader.py +0 -0
  12. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/data/sources/gcs.py +0 -0
  13. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/metrics/inception.py +0 -0
  14. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/metrics/utils.py +0 -0
  15. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/models/__init__.py +0 -0
  16. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/models/attention.py +0 -0
  17. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/models/autoencoder/__init__.py +0 -0
  18. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  19. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/models/autoencoder/diffusers.py +0 -0
  20. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  21. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/models/common.py +0 -0
  22. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/models/favor_fastattn.py +0 -0
  23. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/models/simple_unet.py +0 -0
  24. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/models/simple_vit.py +0 -0
  25. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/predictors/__init__.py +0 -0
  26. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/samplers/__init__.py +0 -0
  27. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/samplers/common.py +0 -0
  28. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/samplers/ddim.py +0 -0
  29. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/samplers/ddpm.py +0 -0
  30. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/samplers/euler.py +0 -0
  31. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/samplers/heun_sampler.py +0 -0
  32. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/samplers/multistep_dpm.py +0 -0
  33. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/samplers/rk4_sampler.py +0 -0
  34. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/schedulers/__init__.py +0 -0
  35. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/schedulers/common.py +0 -0
  36. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/schedulers/continuous.py +0 -0
  37. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/schedulers/cosine.py +0 -0
  38. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/schedulers/discrete.py +0 -0
  39. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/schedulers/exp.py +0 -0
  40. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/schedulers/karras.py +0 -0
  41. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/schedulers/linear.py +0 -0
  42. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/schedulers/sqrt.py +0 -0
  43. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/trainer/__init__.py +0 -0
  44. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
  45. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/trainer/simple_trainer.py +0 -0
  46. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/trainer/video_diffusion_trainer.py +0 -0
  47. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff/utils.py +0 -0
  48. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff.egg-info/SOURCES.txt +0 -0
  49. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff.egg-info/dependency_links.txt +0 -0
  50. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff.egg-info/requires.txt +0 -0
  51. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/flaxdiff.egg-info/top_level.txt +0 -0
  52. {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.4}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.1.36.3
3
+ Version: 0.1.36.4
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
@@ -4,6 +4,8 @@ import grain.python as pygrain
4
4
  from flaxdiff.utils import AutoTextTokenizer
5
5
  from typing import Dict
6
6
  import random
7
+ import augmax
8
+ import jax
7
9
 
8
10
  # -----------------------------------------------------------------------------------------------#
9
11
  # Oxford flowers and other TFDS datasources -----------------------------------------------------#
@@ -47,6 +49,15 @@ def tfds_augmenters(image_scale, method):
47
49
  interpolation = cv2.INTER_CUBIC
48
50
  else:
49
51
  interpolation = cv2.INTER_AREA
52
+
53
+ augments = augmax.Chain(
54
+ augmax.HorizontalFlip(0.5),
55
+ augmax.RandomContrast((-0.05, 0.05), 1.),
56
+ augmax.RandomBrightness((-0.2, 0.2), 1.)
57
+ )
58
+
59
+ augments = jax.jit(augments, backend="cpu")
60
+
50
61
  class augmenters(pygrain.MapTransform):
51
62
  def __init__(self, *args, **kwargs):
52
63
  super().__init__(*args, **kwargs)
@@ -56,6 +67,7 @@ def tfds_augmenters(image_scale, method):
56
67
  image = element['image']
57
68
  image = cv2.resize(image, (image_scale, image_scale),
58
69
  interpolation=interpolation)
70
+ # image = augments(image)
59
71
  # image = (image - 127.5) / 127.5
60
72
  caption = labelizer(element)
61
73
  results = self.tokenize(caption)
@@ -14,6 +14,7 @@ from typing import Dict, Callable, Sequence, Any, Union, Tuple, Type
14
14
  from ..schedulers import NoiseScheduler
15
15
  from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
16
16
  from ..samplers.common import DiffusionSampler
17
+ from ..samplers.ddim import DDIMSampler
17
18
 
18
19
  from flaxdiff.utils import RandomMarkovState
19
20
 
@@ -179,9 +180,6 @@ class DiffusionTrainer(SimpleTrainer):
179
180
  nloss = loss_fn(preds, expected_output)
180
181
  # Ignore the loss contribution of images with zero standard deviation
181
182
  nloss *= noise_schedule.get_weights(noise_level)
182
- # nloss = jnp.mean(nloss, axis=(1,2,3))
183
- # nloss = jnp.where(is_non_zero, nloss, 0)
184
- # nloss = jnp.mean(nloss, where=nloss != 0)
185
183
  nloss = jnp.mean(nloss)
186
184
  loss = nloss
187
185
  return loss
@@ -224,11 +222,11 @@ class DiffusionTrainer(SimpleTrainer):
224
222
  if distributed_training:
225
223
  train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
226
224
  out_specs=(P(), P(), P()))
227
- train_step = jax.jit(train_step)
225
+ train_step = jax.jit(train_step)
228
226
 
229
227
  return train_step
230
228
 
231
- def _define_vaidation_step(self, sampler_class: Type[DiffusionSampler]):
229
+ def _define_vaidation_step(self, sampler_class: Type[DiffusionSampler]=DDIMSampler, sampling_noise_schedule: NoiseScheduler=None):
232
230
  model = self.model
233
231
  encoder = self.encoder
234
232
  autoencoder = self.autoencoder
@@ -241,7 +239,7 @@ class DiffusionTrainer(SimpleTrainer):
241
239
  sampler = sampler_class(
242
240
  model=model,
243
241
  params=state.ema_params,
244
- noise_schedule=self.noise_schedule,
242
+ noise_schedule=self.noise_schedule if sampling_noise_schedule is None else sampling_noise_schedule,
245
243
  model_output_transform=self.model_output_transform,
246
244
  image_size=self.input_shapes['x'][0],
247
245
  null_labels_seq=null_labels_full,
@@ -311,10 +309,11 @@ class DiffusionTrainer(SimpleTrainer):
311
309
  print("Error logging images to wandb", e)
312
310
  traceback.print_exc()
313
311
 
314
- def fit(self, data, training_steps_per_epoch, epochs, val_steps_per_epoch=8, sampler_class=None):
312
+ def fit(self, data, training_steps_per_epoch, epochs, val_steps_per_epoch=8, sampler_class: Type[DiffusionSampler]=DDIMSampler, sampling_noise_schedule: NoiseScheduler=None):
315
313
  local_batch_size = data['local_batch_size']
316
314
  validation_step_args = {
317
315
  "sampler_class": sampler_class,
316
+ "sampling_noise_schedule": sampling_noise_schedule,
318
317
  }
319
318
  super().fit(
320
319
  data,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.1.36.3
3
+ Version: 0.1.36.4
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "flaxdiff"
7
- version = "0.1.36.3"
7
+ version = "0.1.36.4"
8
8
  description = "A versatile and easy to understand Diffusion library"
9
9
  readme = "README.md"
10
10
  authors = [
File without changes
File without changes