flaxdiff 0.1.29__py3-none-any.whl → 0.1.31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flaxdiff/models/common.py CHANGED
@@ -8,7 +8,7 @@ import einops
8
8
  from functools import partial
9
9
 
10
10
  # Kernel initializer to use
11
- def kernel_init(scale, dtype=jnp.float32):
11
+ def kernel_init(scale=1.0, dtype=jnp.float32):
12
12
  scale = max(scale, 1e-10)
13
13
  return nn.initializers.variance_scaling(scale=scale, mode="fan_avg", distribution="truncated_normal", dtype=dtype)
14
14
 
@@ -67,7 +67,7 @@ class UViT(nn.Module):
67
67
  norm_groups:int=8
68
68
  dtype: Optional[Dtype] = None
69
69
  precision: PrecisionLike = None
70
- kernel_init: Callable = partial(kernel_init, 1.0)
70
+ kernel_init: Callable = partial(kernel_init, scale=1.0)
71
71
  add_residualblock_output: bool = False
72
72
 
73
73
  def setup(self):
@@ -151,7 +151,7 @@ class UViT(nn.Module):
151
151
  kernel_size=(3, 3),
152
152
  strides=(1, 1),
153
153
  # activation=jax.nn.mish
154
- kernel_init=self.kernel_init(0.0),
154
+ kernel_init=self.kernel_init(scale=0.0),
155
155
  dtype=self.dtype,
156
156
  precision=self.precision
157
157
  )(x)
@@ -165,7 +165,7 @@ class UViT(nn.Module):
165
165
  kernel_size=(3, 3),
166
166
  strides=(1, 1),
167
167
  # activation=jax.nn.mish
168
- kernel_init=self.kernel_init(0.0),
168
+ kernel_init=self.kernel_init(scale=0.0),
169
169
  dtype=self.dtype,
170
170
  precision=self.precision
171
171
  )(x)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flaxdiff
3
- Version: 0.1.29
3
+ Version: 0.1.31
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author: Ashish Kumar Singh
6
6
  Author-email: ashishkmr472@gmail.com
@@ -4,10 +4,10 @@ flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
4
4
  flaxdiff/data/online_loader.py,sha256=LIK_O1C3yDPvvAEOWvsJrVeBopVqjg2IOMTbiSIvH6M,11025
5
5
  flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
6
6
  flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
7
- flaxdiff/models/common.py,sha256=fw_gP7PZayO6RVe6xSf-7FtVq-S0pp5U6NgHg4PlKO8,10990
7
+ flaxdiff/models/common.py,sha256=hWsSs2BP2J-JN1s4qLRr-h-KYkcVyl2hOp1Wsm_L-h8,10994
8
8
  flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
9
9
  flaxdiff/models/simple_unet.py,sha256=h1o9mQlLJy7Ec8Pz_O5miRbAyUaM5UNhSs-oXzpQvZo,10763
10
- flaxdiff/models/simple_vit.py,sha256=atjeXc22w8WYub_6d0JAFFgvQ4TP1wt4N1ubIzZlQZ0,7436
10
+ flaxdiff/models/simple_vit.py,sha256=NHt6v-teGjiI65fk1l1WN3WqfeqTE7xY9VYqBiYUDgI,7454
11
11
  flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
12
12
  flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
13
13
  flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
34
34
  flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
35
35
  flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
36
36
  flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
37
- flaxdiff-0.1.29.dist-info/METADATA,sha256=PcevgEjt61-62ccMC_CI4EvHYUX-tdrpEBptKXkTudA,22083
38
- flaxdiff-0.1.29.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
39
- flaxdiff-0.1.29.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
40
- flaxdiff-0.1.29.dist-info/RECORD,,
37
+ flaxdiff-0.1.31.dist-info/METADATA,sha256=Vh-cTPdUEyYBY-SFA7GxIVXucEOlDqDqkGULQyw6TIM,22083
38
+ flaxdiff-0.1.31.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
39
+ flaxdiff-0.1.31.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
40
+ flaxdiff-0.1.31.dist-info/RECORD,,