flaxdiff 0.1.23__tar.gz → 0.1.24__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/PKG-INFO +1 -1
  2. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/models/simple_vit.py +6 -3
  3. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff.egg-info/PKG-INFO +1 -1
  4. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/setup.py +1 -1
  5. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/README.md +0 -0
  6. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/__init__.py +0 -0
  7. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/data/__init__.py +0 -0
  8. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/data/online_loader.py +0 -0
  9. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/models/__init__.py +0 -0
  10. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/models/attention.py +0 -0
  11. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/models/autoencoder/__init__.py +0 -0
  12. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  13. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/models/autoencoder/diffusers.py +0 -0
  14. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  15. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/models/common.py +0 -0
  16. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/models/favor_fastattn.py +0 -0
  17. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/models/simple_unet.py +0 -0
  18. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/predictors/__init__.py +0 -0
  19. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/samplers/__init__.py +0 -0
  20. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/samplers/common.py +0 -0
  21. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/samplers/ddim.py +0 -0
  22. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/samplers/ddpm.py +0 -0
  23. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/samplers/euler.py +0 -0
  24. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/samplers/heun_sampler.py +0 -0
  25. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/samplers/multistep_dpm.py +0 -0
  26. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/samplers/rk4_sampler.py +0 -0
  27. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/schedulers/__init__.py +0 -0
  28. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/schedulers/common.py +0 -0
  29. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/schedulers/continuous.py +0 -0
  30. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/schedulers/cosine.py +0 -0
  31. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/schedulers/discrete.py +0 -0
  32. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/schedulers/exp.py +0 -0
  33. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/schedulers/karras.py +0 -0
  34. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/schedulers/linear.py +0 -0
  35. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/schedulers/sqrt.py +0 -0
  36. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/trainer/__init__.py +0 -0
  37. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
  38. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/trainer/diffusion_trainer.py +0 -0
  39. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/trainer/simple_trainer.py +0 -0
  40. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/utils.py +0 -0
  41. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff.egg-info/SOURCES.txt +0 -0
  42. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff.egg-info/dependency_links.txt +0 -0
  43. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff.egg-info/requires.txt +0 -0
  44. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff.egg-info/top_level.txt +0 -0
  45. {flaxdiff-0.1.23 → flaxdiff-0.1.24}/setup.cfg +0 -0
{flaxdiff-0.1.23 → flaxdiff-0.1.24}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.23
+Version: 0.1.24
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
{flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff/models/simple_vit.py
@@ -58,6 +58,9 @@ class UViT(nn.Module):
     dtype: Any = jnp.float32
     precision: Any = jax.lax.Precision.HIGH
     use_projection: bool = False
+    use_flash_attention: bool = False
+    use_self_and_cross: bool = False
+    force_fp32_for_softmax: bool = True
     activation:Callable = jax.nn.swish
     norm_groups:int=8
     dtype: Optional[Dtype] = None
@@ -102,7 +105,7 @@ class UViT(nn.Module):
         for i in range(self.num_layers // 2):
             x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                 dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                use_flash_attention=False, use_self_and_cross=False, force_fp32_for_softmax=True,
+                use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
                 only_pure_attention=False,
                 kernel_init=self.kernel_init())(x)
             skips.append(x)
@@ -110,7 +113,7 @@ class UViT(nn.Module):
         # Middle block
         x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
             dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-            use_flash_attention=False, use_self_and_cross=True, force_fp32_for_softmax=True,
+            use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
             only_pure_attention=False,
             kernel_init=self.kernel_init())(x)
 
@@ -121,7 +124,7 @@ class UViT(nn.Module):
             dtype=self.dtype, precision=self.precision)(skip)
         x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
             dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-            use_flash_attention=False, use_self_and_cross=False, force_fp32_for_softmax=True,
+            use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
             only_pure_attention=False,
             kernel_init=self.kernel_init())(skip)
 
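The net effect of the simple_vit.py change is that the attention behaviour of UViT is now configurable at construction time instead of being hardcoded inside each TransformerBlock call. A minimal sketch of how the new fields might be set, assuming every other UViT field keeps the default shown in the module definition (only the three keyword arguments below come from this diff):

# Sketch only: the three keyword arguments are the fields added in 0.1.24;
# all other UViT fields are assumed to keep their default values.
from flaxdiff.models.simple_vit import UViT

model = UViT(
    use_flash_attention=False,   # previously hardcoded to False in every block
    use_self_and_cross=True,     # previously False in the down/up blocks, True in the middle block
    force_fp32_for_softmax=True, # previously hardcoded to True in every block
)

Note that before this change the middle block always ran with use_self_and_cross=True while the down and up blocks did not; with a single shared field that defaults to False, the middle block's default behaviour changes unless the caller sets use_self_and_cross explicitly.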
{flaxdiff-0.1.23 → flaxdiff-0.1.24}/flaxdiff.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.23
+Version: 0.1.24
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
{flaxdiff-0.1.23 → flaxdiff-0.1.24}/setup.py
@@ -11,7 +11,7 @@ required_packages=[
 setup(
     name='flaxdiff',
     packages=find_packages(),
-    version='0.1.23',
+    version='0.1.24',
     description='A versatile and easy to understand Diffusion library',
     long_description=open('README.md').read(),
     long_description_content_type='text/markdown',