flaxdiff 0.1.35.4.tar.gz → 0.1.35.5.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/PKG-INFO +1 -1
  2. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/models/attention.py +12 -6
  3. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/models/simple_unet.py +6 -0
  4. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/models/simple_vit.py +10 -2
  5. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff.egg-info/PKG-INFO +1 -1
  6. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/setup.py +1 -1
  7. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/README.md +0 -0
  8. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/__init__.py +0 -0
  9. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/data/__init__.py +0 -0
  10. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/data/online_loader.py +0 -0
  11. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/models/__init__.py +0 -0
  12. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/models/autoencoder/__init__.py +0 -0
  13. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  14. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/models/autoencoder/diffusers.py +0 -0
  15. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  16. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/models/common.py +0 -0
  17. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/models/favor_fastattn.py +0 -0
  18. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/predictors/__init__.py +0 -0
  19. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/samplers/__init__.py +0 -0
  20. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/samplers/common.py +0 -0
  21. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/samplers/ddim.py +0 -0
  22. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/samplers/ddpm.py +0 -0
  23. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/samplers/euler.py +0 -0
  24. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/samplers/heun_sampler.py +0 -0
  25. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/samplers/multistep_dpm.py +0 -0
  26. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/samplers/rk4_sampler.py +0 -0
  27. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/schedulers/__init__.py +0 -0
  28. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/schedulers/common.py +0 -0
  29. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/schedulers/continuous.py +0 -0
  30. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/schedulers/cosine.py +0 -0
  31. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/schedulers/discrete.py +0 -0
  32. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/schedulers/exp.py +0 -0
  33. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/schedulers/karras.py +0 -0
  34. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/schedulers/linear.py +0 -0
  35. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/schedulers/sqrt.py +0 -0
  36. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/trainer/__init__.py +0 -0
  37. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
  38. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/trainer/diffusion_trainer.py +0 -0
  39. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/trainer/simple_trainer.py +0 -0
  40. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/utils.py +0 -0
  41. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff.egg-info/SOURCES.txt +0 -0
  42. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff.egg-info/dependency_links.txt +0 -0
  43. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff.egg-info/requires.txt +0 -0
  44. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff.egg-info/top_level.txt +0 -0
  45. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/setup.cfg +0 -0
{flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.35.4
+Version: 0.1.35.5
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
{flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/models/attention.py
@@ -303,27 +303,30 @@ class TransformerBlock(nn.Module):
     only_pure_attention:bool = False
     force_fp32_for_softmax: bool = True
     kernel_init: Callable = kernel_init(1.0)
+    norm_inputs: bool = True
+    explicitly_add_residual: bool = True
 
     @nn.compact
     def __call__(self, x, context=None):
         inner_dim = self.heads * self.dim_head
         C = x.shape[-1]
-        normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
+        if self.norm_inputs:
+            x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
         if self.use_projection == True:
             if self.use_linear_attention:
                 projected_x = nn.Dense(features=inner_dim,
                                        use_bias=False, precision=self.precision,
                                        kernel_init=self.kernel_init,
-                                       dtype=self.dtype, name=f'project_in')(normed_x)
+                                       dtype=self.dtype, name=f'project_in')(x)
             else:
                 projected_x = nn.Conv(
                     features=inner_dim, kernel_size=(1, 1),
                     kernel_init=self.kernel_init,
                     strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                     precision=self.precision, name=f'project_in_conv',
-                )(normed_x)
+                )(x)
         else:
-            projected_x = normed_x
+            projected_x = x
             inner_dim = C
 
         context = projected_x if context is None else context
@@ -356,6 +359,9 @@ class TransformerBlock(nn.Module):
                 strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                 precision=self.precision, name=f'project_out_conv',
             )(projected_x)
-
-        out = x + projected_x
+
+        if self.only_pure_attention or self.explicitly_add_residual:
+            projected_x = x + projected_x
+
+        out = projected_x
         return out
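
The attention.py change turns the previously unconditional input RMSNorm and final residual add into two switches, norm_inputs and explicitly_add_residual, both defaulting to True. Below is a minimal, hedged sketch of driving the new flags directly; the import path follows the file list above, the shapes and hyperparameters are illustrative, and the remaining TransformerBlock fields are assumed to keep their defaults (this is not taken from the package's own documentation):

# Minimal sketch of the new TransformerBlock flags in 0.1.35.5. Import path is
# taken from the file list above; shapes and hyperparameters are illustrative,
# and all fields not set here are assumed to have workable defaults.
import jax
import jax.numpy as jnp
from flaxdiff.models.attention import TransformerBlock

block = TransformerBlock(
    heads=8, dim_head=64,            # 8 * 64 matches the channel dim below
    use_projection=False,
    use_flash_attention=False,
    only_pure_attention=False,
    norm_inputs=True,                # new: apply RMSNorm to the block input (previously always applied)
    explicitly_add_residual=True,    # new: add x back onto the output (previously always added here)
)

x = jnp.ones((1, 256, 512))          # (batch, tokens, channels); self-attention only, no context
params = block.init(jax.random.PRNGKey(0), x)
y = block.apply(params, x)           # y.shape == x.shape
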
{flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/models/simple_unet.py
@@ -83,6 +83,8 @@ class Unet(nn.Module):
                     precision=attention_config.get("precision", self.precision),
                     only_pure_attention=attention_config.get("only_pure_attention", True),
                     force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False),
+                    norm_inputs=attention_config.get("norm_inputs", True),
+                    explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
                     kernel_init=self.kernel_init(1.0),
                     name=f"down_{i}_attention_{j}")(x, textcontext)
                 # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
@@ -125,6 +127,8 @@ class Unet(nn.Module):
                 precision=middle_attention.get("precision", self.precision),
                 only_pure_attention=middle_attention.get("only_pure_attention", True),
                 force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
+                norm_inputs=middle_attention.get("norm_inputs", True),
+                explicitly_add_residual=middle_attention.get("explicitly_add_residual", True),
                 kernel_init=self.kernel_init(1.0),
                 name=f"middle_attention_{j}")(x, textcontext)
             x = ResidualBlock(
@@ -171,6 +175,8 @@ class Unet(nn.Module):
                     precision=attention_config.get("precision", self.precision),
                     only_pure_attention=attention_config.get("only_pure_attention", True),
                     force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
+                    norm_inputs=attention_config.get("norm_inputs", True),
+                    explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
                     kernel_init=self.kernel_init(1.0),
                     name=f"up_{i}_attention_{j}")(x, textcontext)
                 # print("Upscaling ", i, x.shape)
{flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff/models/simple_vit.py
@@ -69,6 +69,8 @@ class UViT(nn.Module):
     precision: PrecisionLike = None
     kernel_init: Callable = partial(kernel_init, scale=1.0)
     add_residualblock_output: bool = False
+    norm_inputs: bool = False
+    explicitly_add_residual: bool = False
 
     def setup(self):
         if self.norm_groups > 0:
@@ -110,16 +112,20 @@ class UViT(nn.Module):
         for i in range(self.num_layers // 2):
             x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                                  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                                 use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
+                                 use_flash_attention=self.use_flash_attention, use_self_and_cross=False, force_fp32_for_softmax=self.force_fp32_for_softmax,
                                  only_pure_attention=False,
+                                 norm_inputs=self.norm_inputs,
+                                 explicitly_add_residual=self.explicitly_add_residual,
                                  kernel_init=self.kernel_init())(x)
             skips.append(x)
 
         # Middle block
         x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                              dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                             use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
+                             use_flash_attention=self.use_flash_attention, use_self_and_cross=False, force_fp32_for_softmax=self.force_fp32_for_softmax,
                              only_pure_attention=False,
+                             norm_inputs=self.norm_inputs,
+                             explicitly_add_residual=self.explicitly_add_residual,
                              kernel_init=self.kernel_init())(x)
 
         # # Out blocks
@@ -131,6 +137,8 @@ class UViT(nn.Module):
                              dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
                              use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
                              only_pure_attention=False,
+                             norm_inputs=self.norm_inputs,
+                             explicitly_add_residual=self.explicitly_add_residual,
                              kernel_init=self.kernel_init())(x)
 
         # print(f'Shape of x after transformer blocks: {x.shape}')
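
UViT gains the same two fields, but with False defaults, and its first-half and middle TransformerBlocks now pass use_self_and_cross=False regardless of the module-level setting. A hedged sketch of turning the new fields on; only field names visible in the hunks above are used, the values are illustrative, and every other UViT field is assumed to have a workable default:

# Hypothetical UViT construction; field names come from the hunks above,
# values are illustrative, and all other fields are assumed to default sensibly.
from flaxdiff.models.simple_vit import UViT

model = UViT(
    num_layers=12,
    num_heads=8,
    emb_features=512,                # per-head dim becomes emb_features // num_heads = 64
    norm_inputs=True,                # new in 0.1.35.5, defaults to False
    explicitly_add_residual=True,    # new in 0.1.35.5, defaults to False
)
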
{flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/flaxdiff.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.35.4
+Version: 0.1.35.5
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
{flaxdiff-0.1.35.4 → flaxdiff-0.1.35.5}/setup.py
@@ -11,7 +11,7 @@ required_packages=[
 setup(
     name='flaxdiff',
     packages=find_packages(),
-    version='0.1.35.4',
+    version='0.1.35.5',
     description='A versatile and easy to understand Diffusion library',
     long_description=open('README.md').read(),
     long_description_content_type='text/markdown',