flaxdiff 0.1.10__py3-none-any.whl → 0.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flaxdiff/models/attention.py

@@ -22,7 +22,8 @@ class EfficientAttention(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     use_bias: bool = True
-    kernel_init: Callable = lambda : kernel_init(1.0)
+    kernel_init: Callable = kernel_init(1.0)
+    force_fp32_for_softmax: bool = True
 
     def setup(self):
         inner_dim = self.dim_head * self.heads
@@ -32,7 +33,7 @@ class EfficientAttention(nn.Module):
             self.heads * self.dim_head,
             precision=self.precision,
             use_bias=self.use_bias,
-            kernel_init=self.kernel_init(),
+            kernel_init=self.kernel_init,
             dtype=self.dtype
         )
         self.query = dense(name="to_q")
@@ -40,7 +41,7 @@ class EfficientAttention(nn.Module):
         self.value = dense(name="to_v")
 
         self.proj_attn = nn.DenseGeneral(self.query_dim, use_bias=False, precision=self.precision,
-                                         kernel_init=self.kernel_init(), dtype=self.dtype, name="to_out_0")
+                                         kernel_init=self.kernel_init, dtype=self.dtype, name="to_out_0")
         # self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16)
 
     def _reshape_tensor_to_head_dim(self, tensor):
@@ -113,7 +114,8 @@ class NormalAttention(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     use_bias: bool = True
-    kernel_init: Callable = lambda : kernel_init(1.0)
+    kernel_init: Callable = kernel_init(1.0)
+    force_fp32_for_softmax: bool = True
 
     def setup(self):
         inner_dim = self.dim_head * self.heads
@@ -123,7 +125,7 @@ class NormalAttention(nn.Module):
             axis=-1,
             precision=self.precision,
             use_bias=self.use_bias,
-            kernel_init=self.kernel_init(),
+            kernel_init=self.kernel_init,
             dtype=self.dtype
         )
         self.query = dense(name="to_q")
@@ -137,7 +139,7 @@ class NormalAttention(nn.Module):
             use_bias=self.use_bias,
             dtype=self.dtype,
             name="to_out_0",
-            kernel_init=self.kernel_init()
+            kernel_init=self.kernel_init
             # kernel_init=jax.nn.initializers.xavier_uniform()
         )
 
@@ -157,7 +159,7 @@ class NormalAttention(nn.Module):
 
         hidden_states = nn.dot_product_attention(
             query, key, value, dtype=self.dtype, broadcast_dropout=False,
-            dropout_rng=None, precision=self.precision, force_fp32_for_softmax=True,
+            dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax,
             deterministic=True
         )
         proj = self.proj_attn(hidden_states)
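
The last hunk above is the consumer of the new field: NormalAttention previously hard-coded force_fp32_for_softmax=True in its call to Flax's nn.dot_product_attention and now forwards self.force_fp32_for_softmax, so callers can choose whether the attention softmax is computed in float32 even when activations are in a lower precision. A minimal, self-contained sketch of what that flag controls (this is not flaxdiff's module; the toy shapes are invented, and the flag requires a recent Flax release):

import jax
import jax.numpy as jnp
import flax.linen as nn

# Toy attention inputs in bfloat16, shaped (batch, length, heads, head_dim).
rng = jax.random.PRNGKey(0)
q = jax.random.normal(rng, (1, 16, 1, 32), dtype=jnp.bfloat16)
k = jax.random.normal(rng, (1, 16, 1, 32), dtype=jnp.bfloat16)
v = jax.random.normal(rng, (1, 16, 1, 32), dtype=jnp.bfloat16)

# With the flag on, the softmax over attention logits is carried out in float32
# (more robust for large logits); with it off, it stays in the input precision.
out_fp32 = nn.dot_product_attention(
    q, k, v, broadcast_dropout=False, deterministic=True,
    force_fp32_for_softmax=True,
)
out_bf16 = nn.dot_product_attention(
    q, k, v, broadcast_dropout=False, deterministic=True,
    force_fp32_for_softmax=False,
)
print(out_fp32.shape, out_bf16.shape)  # same shapes; numerics may differ slightly
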
@@ -233,10 +235,11 @@ class BasicTransformerBlock(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     use_bias: bool = True
-    kernel_init: Callable = lambda : kernel_init(1.0)
+    kernel_init: Callable = kernel_init(1.0)
     use_flash_attention:bool = False
     use_cross_only:bool = False
     only_pure_attention:bool = False
+    force_fp32_for_softmax: bool = True
 
     def setup(self):
         if self.use_flash_attention:
@@ -252,7 +255,8 @@ class BasicTransformerBlock(nn.Module):
             precision=self.precision,
             use_bias=self.use_bias,
             dtype=self.dtype,
-            kernel_init=self.kernel_init
+            kernel_init=self.kernel_init,
+            force_fp32_for_softmax=self.force_fp32_for_softmax
         )
         self.attention2 = attenBlock(
             query_dim=self.query_dim,
@@ -262,7 +266,8 @@ class BasicTransformerBlock(nn.Module):
             precision=self.precision,
             use_bias=self.use_bias,
             dtype=self.dtype,
-            kernel_init=self.kernel_init
+            kernel_init=self.kernel_init,
+            force_fp32_for_softmax=self.force_fp32_for_softmax
         )
 
         self.ff = FlaxFeedForward(dim=self.query_dim)
@@ -296,6 +301,8 @@ class TransformerBlock(nn.Module):
     use_flash_attention:bool = False
     use_self_and_cross:bool = True
     only_pure_attention:bool = False
+    force_fp32_for_softmax: bool = True
+    kernel_init: Callable = kernel_init(1.0)
 
     @nn.compact
     def __call__(self, x, context=None):
@@ -306,12 +313,12 @@ class TransformerBlock(nn.Module):
         if self.use_linear_attention:
             projected_x = nn.Dense(features=inner_dim,
                                    use_bias=False, precision=self.precision,
-                                   kernel_init=kernel_init(1.0),
+                                   kernel_init=self.kernel_init,
                                    dtype=self.dtype, name=f'project_in')(normed_x)
         else:
             projected_x = nn.Conv(
                 features=inner_dim, kernel_size=(1, 1),
-                kernel_init=kernel_init(1.0),
+                kernel_init=self.kernel_init,
                 strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                 precision=self.precision, name=f'project_in_conv',
             )(normed_x)
@@ -331,19 +338,21 @@ class TransformerBlock(nn.Module):
             dtype=self.dtype,
             use_flash_attention=self.use_flash_attention,
             use_cross_only=(not self.use_self_and_cross),
-            only_pure_attention=self.only_pure_attention
+            only_pure_attention=self.only_pure_attention,
+            force_fp32_for_softmax=self.force_fp32_for_softmax,
+            kernel_init=self.kernel_init
         )(projected_x, context)
 
         if self.use_projection == True:
             if self.use_linear_attention:
                 projected_x = nn.Dense(features=C, precision=self.precision,
                                        dtype=self.dtype, use_bias=False,
-                                       kernel_init=kernel_init(1.0),
+                                       kernel_init=self.kernel_init,
                                        name=f'project_out')(projected_x)
             else:
                 projected_x = nn.Conv(
                     features=C, kernel_size=(1, 1),
-                    kernel_init=kernel_init(1.0),
+                    kernel_init=self.kernel_init,
                     strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                     precision=self.precision, name=f'project_out_conv',
                 )(projected_x)
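
Taken together, the attention.py hunks make one consistent change: kernel_init is no longer a zero-argument lambda that every call site had to invoke as self.kernel_init(), but an initializer object stored on the module and handed directly to nn.Dense, nn.Conv, and nn.DenseGeneral; TransformerBlock also gains its own kernel_init and force_fp32_for_softmax fields and threads both down through BasicTransformerBlock into the attention modules. A stripped-down sketch of the new field style, using a stand-in kernel_init helper (the real helper lives elsewhere in flaxdiff and its exact definition is assumed here):

from typing import Callable

import jax
import jax.numpy as jnp
import flax.linen as nn

def kernel_init(scale: float, dtype=jnp.float32):
    # Stand-in for flaxdiff's helper: a scale-parameterised variance-scaling initializer.
    return nn.initializers.variance_scaling(
        scale=max(scale, 1e-10), mode="fan_avg", distribution="uniform", dtype=dtype)

class TinyBlock(nn.Module):
    features: int
    # New style: the dataclass field holds a ready-made initializer...
    kernel_init: Callable = kernel_init(1.0)

    @nn.compact
    def __call__(self, x):
        # ...so call sites pass self.kernel_init directly instead of self.kernel_init().
        return nn.Dense(self.features, kernel_init=self.kernel_init)(x)

x = jnp.ones((1, 8))
params = TinyBlock(features=16).init(jax.random.PRNGKey(0), x)
y = TinyBlock(features=16).apply(params, x)
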
flaxdiff/models/simple_unet.py

@@ -20,6 +20,7 @@ class Unet(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     named_norms: bool = False # This is for backward compatibility reasons; older checkpoints have named norms
+    kernel_init: Callable = partial(kernel_init, dtype=jnp.float32)
 
     def setup(self):
         if self.norm_groups > 0:
@@ -49,7 +50,7 @@ class Unet(nn.Module):
             features=self.feature_depths[0],
             kernel_size=(3, 3),
             strides=(1, 1),
-            kernel_init=kernel_init(1.0),
+            kernel_init=self.kernel_init(1.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
@@ -64,7 +65,7 @@ class Unet(nn.Module):
                     down_conv_type,
                     name=f"down_{i}_residual_{j}",
                     features=dim_in,
-                    kernel_init=kernel_init(1.0),
+                    kernel_init=self.kernel_init(1.0),
                     kernel_size=(3, 3),
                     strides=(1, 1),
                     activation=self.activation,
@@ -81,6 +82,8 @@ class Unet(nn.Module):
                         use_self_and_cross=attention_config.get("use_self_and_cross", True),
                         precision=attention_config.get("precision", self.precision),
                         only_pure_attention=attention_config.get("only_pure_attention", True),
+                        force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False),
+                        kernel_init=self.kernel_init(1.0),
                         name=f"down_{i}_attention_{j}")(x, textcontext)
                 # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
                 downs.append(x)
@@ -103,7 +106,7 @@ class Unet(nn.Module):
                 middle_conv_type,
                 name=f"middle_res1_{j}",
                 features=middle_dim_out,
-                kernel_init=kernel_init(1.0),
+                kernel_init=self.kernel_init(1.0),
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 activation=self.activation,
@@ -121,12 +124,14 @@ class Unet(nn.Module):
                     use_self_and_cross=False,
                     precision=middle_attention.get("precision", self.precision),
                     only_pure_attention=middle_attention.get("only_pure_attention", True),
+                    force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
+                    kernel_init=self.kernel_init(1.0),
                     name=f"middle_attention_{j}")(x, textcontext)
             x = ResidualBlock(
                 middle_conv_type,
                 name=f"middle_res2_{j}",
                 features=middle_dim_out,
-                kernel_init=kernel_init(1.0),
+                kernel_init=self.kernel_init(1.0),
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 activation=self.activation,
@@ -148,7 +153,7 @@ class Unet(nn.Module):
                     up_conv_type,# if j == 0 else "separable",
                     name=f"up_{i}_residual_{j}",
                     features=dim_out,
-                    kernel_init=kernel_init(1.0),
+                    kernel_init=self.kernel_init(1.0),
                     kernel_size=kernel_size,
                     strides=(1, 1),
                     activation=self.activation,
@@ -165,6 +170,8 @@ class Unet(nn.Module):
                        use_self_and_cross=attention_config.get("use_self_and_cross", True),
                        precision=attention_config.get("precision", self.precision),
                        only_pure_attention=attention_config.get("only_pure_attention", True),
+                       force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
+                       kernel_init=self.kernel_init(1.0),
                        name=f"up_{i}_attention_{j}")(x, textcontext)
                 # print("Upscaling ", i, x.shape)
                 if i != len(feature_depths) - 1:
@@ -183,7 +190,7 @@ class Unet(nn.Module):
             features=self.feature_depths[0],
             kernel_size=(3, 3),
             strides=(1, 1),
-            kernel_init=kernel_init(1.0),
+            kernel_init=self.kernel_init(1.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
@@ -194,7 +201,7 @@ class Unet(nn.Module):
             conv_type,
             name="final_residual",
             features=self.feature_depths[0],
-            kernel_init=kernel_init(1.0),
+            kernel_init=self.kernel_init(1.0),
             kernel_size=(3,3),
             strides=(1, 1),
             activation=self.activation,
@@ -213,7 +220,7 @@ class Unet(nn.Module):
             kernel_size=(3, 3),
             strides=(1, 1),
             # activation=jax.nn.mish
-            kernel_init=kernel_init(0.0),
+            kernel_init=self.kernel_init(0.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
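
The simple_unet.py side keeps a factory instead: Unet now carries kernel_init: Callable = partial(kernel_init, dtype=jnp.float32) as a field, and every ResidualBlock, Conv, and TransformerBlock call site builds its initializer through self.kernel_init(1.0) (or self.kernel_init(0.0) for the zero-scaled output convolution) rather than the module-level kernel_init(...). Note that the up-block attention hunk reads its force_fp32_for_softmax default from middle_attention rather than attention_config, as it appears in the released file. A small sketch of the factory pattern, again with an assumed stand-in for the helper:

from functools import partial
from typing import Callable

import jax
import jax.numpy as jnp
import flax.linen as nn

def kernel_init(scale: float, dtype=jnp.float32):
    # Stand-in for flaxdiff's helper (assumed): scale-parameterised initializer.
    return nn.initializers.variance_scaling(
        scale=max(scale, 1e-10), mode="fan_avg", distribution="uniform", dtype=dtype)

class TinyUnetStem(nn.Module):
    features: int
    # The Unet stores a factory with the dtype pre-bound; each layer picks its scale.
    kernel_init: Callable = partial(kernel_init, dtype=jnp.float32)

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(self.features, kernel_size=(3, 3),
                    kernel_init=self.kernel_init(1.0))(x)  # standard-scale init
        x = nn.Conv(3, kernel_size=(3, 3),
                    kernel_init=self.kernel_init(0.0))(x)  # near-zero init for the output conv
        return x

x = jnp.ones((1, 32, 32, 3))
params = TinyUnetStem(features=16).init(jax.random.PRNGKey(0), x)
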
flaxdiff-{0.1.10 → 0.1.12}.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.10
+Version: 0.1.12
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
flaxdiff-{0.1.10 → 0.1.12}.dist-info/RECORD

@@ -1,10 +1,10 @@
 flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
-flaxdiff/models/attention.py,sha256=YyVI3dTAMB8cS8VWHgtIigr2YY-MYfFTlaNDfjNJOCk,12596
+flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
 flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
 flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
-flaxdiff/models/simple_unet.py,sha256=H67Pfy8BqKHvhdw_K3lBiFdruNQFBMElw8SDZdvg9Ec,10084
+flaxdiff/models/simple_unet.py,sha256=h1o9mQlLJy7Ec8Pz_O5miRbAyUaM5UNhSs-oXzpQvZo,10763
 flaxdiff/models/simple_vit.py,sha256=xD23i1b7WEvoH4tUMsLyCe9ebDcv-PpaV0Nso38Jlb8,3887
 flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
 flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
@@ -32,7 +32,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
 flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
 flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
 flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
-flaxdiff-0.1.10.dist-info/METADATA,sha256=q9O56jlhtuznnbmlHeKa9-gLFtWXge0bwBU6g9_P8Jk,22083
-flaxdiff-0.1.10.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
-flaxdiff-0.1.10.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
-flaxdiff-0.1.10.dist-info/RECORD,,
+flaxdiff-0.1.12.dist-info/METADATA,sha256=s3rIj9jqh1Xr1NABqOJZw9XwyHdaLd01c_jKpzEMErQ,22083
+flaxdiff-0.1.12.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+flaxdiff-0.1.12.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.12.dist-info/RECORD,,