flaxdiff 0.1.11__py3-none-any.whl → 0.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- flaxdiff/models/attention.py (0.1.11)
+++ flaxdiff/models/attention.py (0.1.12)
@@ -22,7 +22,7 @@ class EfficientAttention(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     use_bias: bool = True
-    kernel_init: Callable = lambda : kernel_init(1.0)
+    kernel_init: Callable = kernel_init(1.0)
     force_fp32_for_softmax: bool = True
 
     def setup(self):
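
The one substantive change in this release, repeated across every attention and transformer module below: the kernel_init field no longer defaults to a zero-argument lambda that produces an initializer, but to the initializer itself, and every call site drops the trailing "()" (self.kernel_init() becomes self.kernel_init). A minimal before/after sketch, assuming a hypothetical kernel_init(scale) factory that returns a Flax variance-scaling initializer (the real factory is defined elsewhere in flaxdiff and does not appear in this diff):

import flax.linen as nn
from typing import Callable

# Hypothetical stand-in for flaxdiff's kernel_init factory: returns a
# Flax initializer function for the given scale.
def kernel_init(scale: float) -> Callable:
    return nn.initializers.variance_scaling(scale, "fan_avg", "uniform")

class AttentionBefore(nn.Module):
    # 0.1.11: the default is a thunk; call sites must invoke it first,
    # e.g. kernel_init=self.kernel_init()
    kernel_init: Callable = lambda: kernel_init(1.0)

class AttentionAfter(nn.Module):
    # 0.1.12: the default is the initializer itself; call sites pass it
    # through unchanged, e.g. kernel_init=self.kernel_init
    kernel_init: Callable = kernel_init(1.0)

Note that this is a small break for callers: code that passed a lambda for kernel_init against 0.1.11 must now pass the initializer function directly.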
@@ -33,7 +33,7 @@ class EfficientAttention(nn.Module):
             self.heads * self.dim_head,
             precision=self.precision,
             use_bias=self.use_bias,
-            kernel_init=self.kernel_init(),
+            kernel_init=self.kernel_init,
             dtype=self.dtype
         )
         self.query = dense(name="to_q")
@@ -41,7 +41,7 @@ class EfficientAttention(nn.Module):
         self.value = dense(name="to_v")
 
         self.proj_attn = nn.DenseGeneral(self.query_dim, use_bias=False, precision=self.precision,
-                                         kernel_init=self.kernel_init(), dtype=self.dtype, name="to_out_0")
+                                         kernel_init=self.kernel_init, dtype=self.dtype, name="to_out_0")
         # self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16)
 
     def _reshape_tensor_to_head_dim(self, tensor):
@@ -114,7 +114,7 @@ class NormalAttention(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     use_bias: bool = True
-    kernel_init: Callable = lambda : kernel_init(1.0)
+    kernel_init: Callable = kernel_init(1.0)
     force_fp32_for_softmax: bool = True
 
     def setup(self):
@@ -125,7 +125,7 @@ class NormalAttention(nn.Module):
             axis=-1,
             precision=self.precision,
             use_bias=self.use_bias,
-            kernel_init=self.kernel_init(),
+            kernel_init=self.kernel_init,
             dtype=self.dtype
         )
         self.query = dense(name="to_q")
@@ -139,7 +139,7 @@ class NormalAttention(nn.Module):
             use_bias=self.use_bias,
             dtype=self.dtype,
             name="to_out_0",
-            kernel_init=self.kernel_init()
+            kernel_init=self.kernel_init
             # kernel_init=jax.nn.initializers.xavier_uniform()
         )
 
@@ -235,7 +235,7 @@ class BasicTransformerBlock(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     use_bias: bool = True
-    kernel_init: Callable = lambda : kernel_init(1.0)
+    kernel_init: Callable = kernel_init(1.0)
     use_flash_attention:bool = False
     use_cross_only:bool = False
     only_pure_attention:bool = False
@@ -302,7 +302,7 @@ class TransformerBlock(nn.Module):
     use_self_and_cross:bool = True
     only_pure_attention:bool = False
     force_fp32_for_softmax: bool = True
-    kernel_init: Callable = lambda : kernel_init(1.0)
+    kernel_init: Callable = kernel_init(1.0)
 
     @nn.compact
     def __call__(self, x, context=None):
@@ -313,12 +313,12 @@ class TransformerBlock(nn.Module):
         if self.use_linear_attention:
             projected_x = nn.Dense(features=inner_dim,
                                    use_bias=False, precision=self.precision,
-                                   kernel_init=self.kernel_init(),
+                                   kernel_init=self.kernel_init,
                                    dtype=self.dtype, name=f'project_in')(normed_x)
         else:
             projected_x = nn.Conv(
                 features=inner_dim, kernel_size=(1, 1),
-                kernel_init=self.kernel_init(),
+                kernel_init=self.kernel_init,
                 strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                 precision=self.precision, name=f'project_in_conv',
             )(normed_x)
@@ -347,12 +347,12 @@ class TransformerBlock(nn.Module):
         if self.use_linear_attention:
             projected_x = nn.Dense(features=C, precision=self.precision,
                                    dtype=self.dtype, use_bias=False,
-                                   kernel_init=self.kernel_init(),
+                                   kernel_init=self.kernel_init,
                                    name=f'project_out')(projected_x)
         else:
             projected_x = nn.Conv(
                 features=C, kernel_size=(1, 1),
-                kernel_init=self.kernel_init(),
+                kernel_init=self.kernel_init,
                 strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                 precision=self.precision, name=f'project_out_conv',
             )(projected_x)
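
For completeness, a hedged usage sketch of the new convention (Proj and its parameters are illustrative, not part of flaxdiff; the kernel_init factory is the same assumed stand-in as in the sketch above): a module stores the initializer and forwards it as-is to inner layers, and overriding the default means passing an initializer function rather than a lambda that produces one.

import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Callable

# Same hypothetical factory as in the earlier sketch.
def kernel_init(scale: float) -> Callable:
    return nn.initializers.variance_scaling(scale, "fan_avg", "uniform")

class Proj(nn.Module):
    # Illustrative module mirroring the pattern in the hunks above: the
    # stored initializer is forwarded untouched to the inner Dense layer.
    features: int = 8
    kernel_init: Callable = kernel_init(1.0)

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.features, kernel_init=self.kernel_init)(x)

x = jnp.ones((1, 4))
# Default initializer:
params = Proj().init(jax.random.PRNGKey(0), x)
# Override with another initializer function (not a lambda producing one):
params = Proj(kernel_init=nn.initializers.xavier_uniform()).init(
    jax.random.PRNGKey(0), x)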
--- flaxdiff-0.1.11.dist-info/METADATA
+++ flaxdiff-0.1.12.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.11
+Version: 0.1.12
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
--- flaxdiff-0.1.11.dist-info/RECORD
+++ flaxdiff-0.1.12.dist-info/RECORD
@@ -1,7 +1,7 @@
 flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
-flaxdiff/models/attention.py,sha256=X1n-50sqDYkmGZsFUHjc04_dxusX-FirrK3PKofcyXo,13075
+flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
 flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
 flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
 flaxdiff/models/simple_unet.py,sha256=h1o9mQlLJy7Ec8Pz_O5miRbAyUaM5UNhSs-oXzpQvZo,10763
@@ -32,7 +32,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
 flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
 flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
 flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
-flaxdiff-0.1.11.dist-info/METADATA,sha256=LR2dU9075s_yEAGkgIpMNmXOYYqcXvyZ-YeVenjcDiI,22083
-flaxdiff-0.1.11.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
-flaxdiff-0.1.11.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
-flaxdiff-0.1.11.dist-info/RECORD,,
+flaxdiff-0.1.12.dist-info/METADATA,sha256=s3rIj9jqh1Xr1NABqOJZw9XwyHdaLd01c_jKpzEMErQ,22083
+flaxdiff-0.1.12.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+flaxdiff-0.1.12.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.12.dist-info/RECORD,,