flaxdiff 0.1.35.4__tar.gz → 0.1.35.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/PKG-INFO +1 -1
  2. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/models/attention.py +13 -6
  3. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/models/simple_unet.py +17 -11
  4. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/models/simple_vit.py +10 -2
  5. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff.egg-info/PKG-INFO +1 -1
  6. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/setup.py +1 -1
  7. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/README.md +0 -0
  8. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/__init__.py +0 -0
  9. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/data/__init__.py +0 -0
  10. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/data/online_loader.py +0 -0
  11. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/models/__init__.py +0 -0
  12. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/models/autoencoder/__init__.py +0 -0
  13. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  14. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/models/autoencoder/diffusers.py +0 -0
  15. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  16. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/models/common.py +0 -0
  17. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/models/favor_fastattn.py +0 -0
  18. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/predictors/__init__.py +0 -0
  19. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/samplers/__init__.py +0 -0
  20. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/samplers/common.py +0 -0
  21. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/samplers/ddim.py +0 -0
  22. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/samplers/ddpm.py +0 -0
  23. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/samplers/euler.py +0 -0
  24. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/samplers/heun_sampler.py +0 -0
  25. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/samplers/multistep_dpm.py +0 -0
  26. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/samplers/rk4_sampler.py +0 -0
  27. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/schedulers/__init__.py +0 -0
  28. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/schedulers/common.py +0 -0
  29. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/schedulers/continuous.py +0 -0
  30. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/schedulers/cosine.py +0 -0
  31. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/schedulers/discrete.py +0 -0
  32. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/schedulers/exp.py +0 -0
  33. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/schedulers/karras.py +0 -0
  34. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/schedulers/linear.py +0 -0
  35. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/schedulers/sqrt.py +0 -0
  36. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/trainer/__init__.py +0 -0
  37. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
  38. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/trainer/diffusion_trainer.py +0 -0
  39. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/trainer/simple_trainer.py +0 -0
  40. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/utils.py +0 -0
  41. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff.egg-info/SOURCES.txt +0 -0
  42. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff.egg-info/dependency_links.txt +0 -0
  43. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff.egg-info/requires.txt +0 -0
  44. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff.egg-info/top_level.txt +0 -0
  45. {flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/setup.cfg +0 -0
{flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: flaxdiff
- Version: 0.1.35.4
+ Version: 0.1.35.6
  Summary: A versatile and easy to understand Diffusion library
  Author: Ashish Kumar Singh
  Author-email: ashishkmr472@gmail.com
{flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/models/attention.py

@@ -11,6 +11,7 @@ import einops
  import functools
  import math
  from .common import kernel_init
+ import jax.experimental.pallas.ops.tpu.flash_attention
  
  class EfficientAttention(nn.Module):
      """
@@ -303,27 +304,30 @@ class TransformerBlock(nn.Module):
      only_pure_attention:bool = False
      force_fp32_for_softmax: bool = True
      kernel_init: Callable = kernel_init(1.0)
+     norm_inputs: bool = True
+     explicitly_add_residual: bool = True
  
      @nn.compact
      def __call__(self, x, context=None):
          inner_dim = self.heads * self.dim_head
          C = x.shape[-1]
-         normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
+         if self.norm_inputs:
+             x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
          if self.use_projection == True:
              if self.use_linear_attention:
                  projected_x = nn.Dense(features=inner_dim,
                      use_bias=False, precision=self.precision,
                      kernel_init=self.kernel_init,
-                     dtype=self.dtype, name=f'project_in')(normed_x)
+                     dtype=self.dtype, name=f'project_in')(x)
              else:
                  projected_x = nn.Conv(
                      features=inner_dim, kernel_size=(1, 1),
                      kernel_init=self.kernel_init,
                      strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                      precision=self.precision, name=f'project_in_conv',
-                 )(normed_x)
+                 )(x)
          else:
-             projected_x = normed_x
+             projected_x = x
              inner_dim = C
  
          context = projected_x if context is None else context
@@ -356,6 +360,9 @@ class TransformerBlock(nn.Module):
                  strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                  precision=self.precision, name=f'project_out_conv',
              )(projected_x)
-
-         out = x + projected_x
+
+         if self.only_pure_attention or self.explicitly_add_residual:
+             projected_x = x + projected_x
+
+         out = projected_x
          return out
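
Together, the attention.py changes make both the input normalization and the final residual add configurable: RMSNorm is applied only when norm_inputs is set (and now overwrites x instead of producing a separate normed_x), and the block input is added back to the output whenever only_pure_attention or explicitly_add_residual holds. A hedged construction example follows; the field names come from this diff, the concrete values are illustrative assumptions.

# Illustrative only: field names are taken from this diff, values are assumptions.
from flaxdiff.models.attention import TransformerBlock

block = TransformerBlock(
    heads=8,
    dim_head=64,
    use_projection=False,
    only_pure_attention=False,
    norm_inputs=True,              # RMSNorm the input before projecting
    explicitly_add_residual=True,  # add the block input back onto the output
)
# params = block.init(jax.random.PRNGKey(0), x)  # x: (batch, seq_len, channels)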
{flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/models/simple_unet.py

@@ -50,7 +50,7 @@ class Unet(nn.Module):
              features=self.feature_depths[0],
              kernel_size=(3, 3),
              strides=(1, 1),
-             kernel_init=self.kernel_init(1.0),
+             kernel_init=self.kernel_init(scale=1.0),
              dtype=self.dtype,
              precision=self.precision
          )(x)
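
The kernel_init change repeated throughout this file is purely a call-site change: scale is now passed by keyword instead of positionally. That only requires the factory imported from flaxdiff/models/common.py (and stored as self.kernel_init) to expose a scale parameter by name; below is a minimal sketch of such a factory, stated as an assumption rather than the library's actual implementation.

# Assumed shape of the kernel_init factory; the real one lives in
# flaxdiff/models/common.py and may differ in mode/distribution.
import flax.linen as nn

def kernel_init(scale: float = 1.0):
    # Variance-scaling initializer whose magnitude is set by `scale`;
    # scale=0.0 is clamped so the initializer stays well defined.
    return nn.initializers.variance_scaling(
        scale=max(scale, 1e-10), mode="fan_avg", distribution="uniform")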
@@ -65,7 +65,7 @@ class Unet(nn.Module):
                      down_conv_type,
                      name=f"down_{i}_residual_{j}",
                      features=dim_in,
-                     kernel_init=self.kernel_init(1.0),
+                     kernel_init=self.kernel_init(scale=1.0),
                      kernel_size=(3, 3),
                      strides=(1, 1),
                      activation=self.activation,
@@ -83,7 +83,9 @@ class Unet(nn.Module):
                      precision=attention_config.get("precision", self.precision),
                      only_pure_attention=attention_config.get("only_pure_attention", True),
                      force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False),
-                     kernel_init=self.kernel_init(1.0),
+                     norm_inputs=attention_config.get("norm_inputs", True),
+                     explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
+                     kernel_init=self.kernel_init(scale=1.0),
                      name=f"down_{i}_attention_{j}")(x, textcontext)
              # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
              downs.append(x)
@@ -106,7 +108,7 @@ class Unet(nn.Module):
              middle_conv_type,
              name=f"middle_res1_{j}",
              features=middle_dim_out,
-             kernel_init=self.kernel_init(1.0),
+             kernel_init=self.kernel_init(scale=1.0),
              kernel_size=(3, 3),
              strides=(1, 1),
              activation=self.activation,
@@ -125,13 +127,15 @@ class Unet(nn.Module):
              precision=middle_attention.get("precision", self.precision),
              only_pure_attention=middle_attention.get("only_pure_attention", True),
              force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
-             kernel_init=self.kernel_init(1.0),
+             norm_inputs=middle_attention.get("norm_inputs", True),
+             explicitly_add_residual=middle_attention.get("explicitly_add_residual", True),
+             kernel_init=self.kernel_init(scale=1.0),
              name=f"middle_attention_{j}")(x, textcontext)
          x = ResidualBlock(
              middle_conv_type,
              name=f"middle_res2_{j}",
              features=middle_dim_out,
-             kernel_init=self.kernel_init(1.0),
+             kernel_init=self.kernel_init(scale=1.0),
              kernel_size=(3, 3),
              strides=(1, 1),
              activation=self.activation,
@@ -153,7 +157,7 @@ class Unet(nn.Module):
                      up_conv_type,# if j == 0 else "separable",
                      name=f"up_{i}_residual_{j}",
                      features=dim_out,
-                     kernel_init=self.kernel_init(1.0),
+                     kernel_init=self.kernel_init(scale=1.0),
                      kernel_size=kernel_size,
                      strides=(1, 1),
                      activation=self.activation,
@@ -171,7 +175,9 @@ class Unet(nn.Module):
                      precision=attention_config.get("precision", self.precision),
                      only_pure_attention=attention_config.get("only_pure_attention", True),
                      force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
-                     kernel_init=self.kernel_init(1.0),
+                     norm_inputs=attention_config.get("norm_inputs", True),
+                     explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
+                     kernel_init=self.kernel_init(scale=1.0),
                      name=f"up_{i}_attention_{j}")(x, textcontext)
              # print("Upscaling ", i, x.shape)
              if i != len(feature_depths) - 1:
@@ -190,7 +196,7 @@ class Unet(nn.Module):
              features=self.feature_depths[0],
              kernel_size=(3, 3),
              strides=(1, 1),
-             kernel_init=self.kernel_init(1.0),
+             kernel_init=self.kernel_init(scale=1.0),
              dtype=self.dtype,
              precision=self.precision
          )(x)
@@ -201,7 +207,7 @@ class Unet(nn.Module):
              conv_type,
              name="final_residual",
              features=self.feature_depths[0],
-             kernel_init=self.kernel_init(1.0),
+             kernel_init=self.kernel_init(scale=1.0),
              kernel_size=(3,3),
              strides=(1, 1),
              activation=self.activation,
@@ -220,7 +226,7 @@ class Unet(nn.Module):
              kernel_size=(3, 3),
              strides=(1, 1),
              # activation=jax.nn.mish
-             kernel_init=self.kernel_init(0.0),
+             kernel_init=self.kernel_init(scale=0.0),
              dtype=self.dtype,
              precision=self.precision
          )(x)
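
The two new TransformerBlock switches are read from the same per-level attention_config dictionaries as the existing keys and default to True when absent. An illustrative config for one feature level follows; only norm_inputs and explicitly_add_residual are new in this release, and the other values are shown purely for context.

# Keys mirror the attention_config.get(...) calls in the hunks above.
attention_config = {
    "only_pure_attention": True,
    "force_fp32_for_softmax": False,
    "norm_inputs": True,              # new in 0.1.35.6; True when omitted
    "explicitly_add_residual": True,  # new in 0.1.35.6; True when omitted
}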
{flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff/models/simple_vit.py

@@ -69,6 +69,8 @@ class UViT(nn.Module):
      precision: PrecisionLike = None
      kernel_init: Callable = partial(kernel_init, scale=1.0)
      add_residualblock_output: bool = False
+     norm_inputs: bool = False
+     explicitly_add_residual: bool = True
  
      def setup(self):
          if self.norm_groups > 0:
@@ -110,16 +112,20 @@ class UViT(nn.Module):
          for i in range(self.num_layers // 2):
              x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                 use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
+                 use_flash_attention=self.use_flash_attention, use_self_and_cross=False, force_fp32_for_softmax=self.force_fp32_for_softmax,
                  only_pure_attention=False,
+                 norm_inputs=self.norm_inputs,
+                 explicitly_add_residual=self.explicitly_add_residual,
                  kernel_init=self.kernel_init())(x)
              skips.append(x)
  
          # Middle block
          x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
              dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-             use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
+             use_flash_attention=self.use_flash_attention, use_self_and_cross=False, force_fp32_for_softmax=self.force_fp32_for_softmax,
              only_pure_attention=False,
+             norm_inputs=self.norm_inputs,
+             explicitly_add_residual=self.explicitly_add_residual,
              kernel_init=self.kernel_init())(x)
  
          # # Out blocks
@@ -131,6 +137,8 @@ class UViT(nn.Module):
                  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
                  use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
                  only_pure_attention=False,
+                 norm_inputs=self.norm_inputs,
+                 explicitly_add_residual=self.explicitly_add_residual,
                  kernel_init=self.kernel_init())(x)
  
          # print(f'Shape of x after transformer blocks: {x.shape}')
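
UViT exposes the same two switches as module fields (defaulting to norm_inputs=False here) and forwards them to every TransformerBlock; the in and middle blocks now also hard-code use_self_and_cross=False, while the out blocks keep the configurable value. A hedged construction example follows, with field names from this diff and the remaining hyperparameters as assumptions.

# Illustrative only: field names appear in this diff, the values are assumptions.
from flaxdiff.models.simple_vit import UViT

model = UViT(
    num_layers=12,
    num_heads=8,
    emb_features=512,
    norm_inputs=False,             # new field; skip the per-block RMSNorm
    explicitly_add_residual=True,  # new field; keep the explicit skip-add
)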
{flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/flaxdiff.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: flaxdiff
- Version: 0.1.35.4
+ Version: 0.1.35.6
  Summary: A versatile and easy to understand Diffusion library
  Author: Ashish Kumar Singh
  Author-email: ashishkmr472@gmail.com
{flaxdiff-0.1.35.4 → flaxdiff-0.1.35.6}/setup.py

@@ -11,7 +11,7 @@ required_packages=[
  setup(
      name='flaxdiff',
      packages=find_packages(),
-     version='0.1.35.4',
+     version='0.1.35.6',
      description='A versatile and easy to understand Diffusion library',
      long_description=open('README.md').read(),
      long_description_content_type='text/markdown',
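
Since setup.py changes only the version string, picking up the behaviour above is a matter of installing the newer release, for example with pip install flaxdiff==0.1.35.6.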