flaxdiff 0.1.29__py3-none-any.whl → 0.1.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/models/common.py +1 -1
- flaxdiff/models/simple_vit.py +3 -3
- {flaxdiff-0.1.29.dist-info → flaxdiff-0.1.31.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.29.dist-info → flaxdiff-0.1.31.dist-info}/RECORD +6 -6
- {flaxdiff-0.1.29.dist-info → flaxdiff-0.1.31.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.29.dist-info → flaxdiff-0.1.31.dist-info}/top_level.txt +0 -0
flaxdiff/models/common.py
CHANGED
@@ -8,7 +8,7 @@ import einops
|
|
8
8
|
from functools import partial
|
9
9
|
|
10
10
|
# Kernel initializer to use
|
11
|
-
def kernel_init(scale, dtype=jnp.float32):
|
11
|
+
def kernel_init(scale=1.0, dtype=jnp.float32):
|
12
12
|
scale = max(scale, 1e-10)
|
13
13
|
return nn.initializers.variance_scaling(scale=scale, mode="fan_avg", distribution="truncated_normal", dtype=dtype)
|
14
14
|
|
flaxdiff/models/simple_vit.py
CHANGED
@@ -67,7 +67,7 @@ class UViT(nn.Module):
|
|
67
67
|
norm_groups:int=8
|
68
68
|
dtype: Optional[Dtype] = None
|
69
69
|
precision: PrecisionLike = None
|
70
|
-
kernel_init: Callable = partial(kernel_init, 1.0)
|
70
|
+
kernel_init: Callable = partial(kernel_init, scale=1.0)
|
71
71
|
add_residualblock_output: bool = False
|
72
72
|
|
73
73
|
def setup(self):
|
@@ -151,7 +151,7 @@ class UViT(nn.Module):
|
|
151
151
|
kernel_size=(3, 3),
|
152
152
|
strides=(1, 1),
|
153
153
|
# activation=jax.nn.mish
|
154
|
-
kernel_init=self.kernel_init(0.0),
|
154
|
+
kernel_init=self.kernel_init(scale=0.0),
|
155
155
|
dtype=self.dtype,
|
156
156
|
precision=self.precision
|
157
157
|
)(x)
|
@@ -165,7 +165,7 @@ class UViT(nn.Module):
|
|
165
165
|
kernel_size=(3, 3),
|
166
166
|
strides=(1, 1),
|
167
167
|
# activation=jax.nn.mish
|
168
|
-
kernel_init=self.kernel_init(0.0),
|
168
|
+
kernel_init=self.kernel_init(scale=0.0),
|
169
169
|
dtype=self.dtype,
|
170
170
|
precision=self.precision
|
171
171
|
)(x)
|
@@ -4,10 +4,10 @@ flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
|
|
4
4
|
flaxdiff/data/online_loader.py,sha256=LIK_O1C3yDPvvAEOWvsJrVeBopVqjg2IOMTbiSIvH6M,11025
|
5
5
|
flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
|
6
6
|
flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
|
7
|
-
flaxdiff/models/common.py,sha256=
|
7
|
+
flaxdiff/models/common.py,sha256=hWsSs2BP2J-JN1s4qLRr-h-KYkcVyl2hOp1Wsm_L-h8,10994
|
8
8
|
flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
|
9
9
|
flaxdiff/models/simple_unet.py,sha256=h1o9mQlLJy7Ec8Pz_O5miRbAyUaM5UNhSs-oXzpQvZo,10763
|
10
|
-
flaxdiff/models/simple_vit.py,sha256=
|
10
|
+
flaxdiff/models/simple_vit.py,sha256=NHt6v-teGjiI65fk1l1WN3WqfeqTE7xY9VYqBiYUDgI,7454
|
11
11
|
flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
|
12
12
|
flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
|
13
13
|
flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
|
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
|
|
34
34
|
flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
|
35
35
|
flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
|
36
36
|
flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
|
37
|
-
flaxdiff-0.1.
|
38
|
-
flaxdiff-0.1.
|
39
|
-
flaxdiff-0.1.
|
40
|
-
flaxdiff-0.1.
|
37
|
+
flaxdiff-0.1.31.dist-info/METADATA,sha256=Vh-cTPdUEyYBY-SFA7GxIVXucEOlDqDqkGULQyw6TIM,22083
|
38
|
+
flaxdiff-0.1.31.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
|
39
|
+
flaxdiff-0.1.31.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
40
|
+
flaxdiff-0.1.31.dist-info/RECORD,,
|
File without changes
|
File without changes
|