flaxdiff 0.1.29__py3-none-any.whl → 0.1.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/models/common.py +1 -1
- flaxdiff/models/simple_vit.py +10 -10
- {flaxdiff-0.1.29.dist-info → flaxdiff-0.1.30.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.29.dist-info → flaxdiff-0.1.30.dist-info}/RECORD +6 -6
- {flaxdiff-0.1.29.dist-info → flaxdiff-0.1.30.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.29.dist-info → flaxdiff-0.1.30.dist-info}/top_level.txt +0 -0
flaxdiff/models/common.py
CHANGED
@@ -8,7 +8,7 @@ import einops
|
|
8
8
|
from functools import partial
|
9
9
|
|
10
10
|
# Kernel initializer to use
|
11
|
-
def kernel_init(scale, dtype=jnp.float32):
|
11
|
+
def kernel_init(scale=1.0, dtype=jnp.float32):
|
12
12
|
scale = max(scale, 1e-10)
|
13
13
|
return nn.initializers.variance_scaling(scale=scale, mode="fan_avg", distribution="truncated_normal", dtype=dtype)
|
14
14
|
|
flaxdiff/models/simple_vit.py
CHANGED
@@ -23,7 +23,7 @@ class PatchEmbedding(nn.Module):
|
|
23
23
|
embedding_dim: int
|
24
24
|
dtype: Any = jnp.float32
|
25
25
|
precision: Any = jax.lax.Precision.HIGH
|
26
|
-
kernel_init: Callable =
|
26
|
+
kernel_init: Callable = kernel_init(1.0)
|
27
27
|
|
28
28
|
@nn.compact
|
29
29
|
def __call__(self, x):
|
@@ -34,7 +34,7 @@ class PatchEmbedding(nn.Module):
|
|
34
34
|
kernel_size=(self.patch_size, self.patch_size),
|
35
35
|
strides=(self.patch_size, self.patch_size),
|
36
36
|
dtype=self.dtype,
|
37
|
-
kernel_init=self.kernel_init
|
37
|
+
kernel_init=self.kernel_init,
|
38
38
|
precision=self.precision)(x)
|
39
39
|
x = jnp.reshape(x, (batch, -1, self.embedding_dim))
|
40
40
|
return x
|
@@ -67,7 +67,7 @@ class UViT(nn.Module):
|
|
67
67
|
norm_groups:int=8
|
68
68
|
dtype: Optional[Dtype] = None
|
69
69
|
precision: PrecisionLike = None
|
70
|
-
kernel_init: Callable = partial(kernel_init
|
70
|
+
kernel_init: Callable = partial(kernel_init)
|
71
71
|
add_residualblock_output: bool = False
|
72
72
|
|
73
73
|
def setup(self):
|
@@ -86,10 +86,10 @@ class UViT(nn.Module):
|
|
86
86
|
|
87
87
|
# Patch embedding
|
88
88
|
x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features,
|
89
|
-
dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init)(x)
|
89
|
+
dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init(1.0))(x)
|
90
90
|
num_patches = x.shape[1]
|
91
91
|
|
92
|
-
context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
|
92
|
+
context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(1.0),
|
93
93
|
dtype=self.dtype, precision=self.precision)(textcontext)
|
94
94
|
num_text_tokens = textcontext.shape[1]
|
95
95
|
|
@@ -112,7 +112,7 @@ class UViT(nn.Module):
|
|
112
112
|
dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
|
113
113
|
use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
|
114
114
|
only_pure_attention=False,
|
115
|
-
kernel_init=self.kernel_init())(x)
|
115
|
+
kernel_init=self.kernel_init(1.0))(x)
|
116
116
|
skips.append(x)
|
117
117
|
|
118
118
|
# Middle block
|
@@ -120,24 +120,24 @@ class UViT(nn.Module):
|
|
120
120
|
dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
|
121
121
|
use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
|
122
122
|
only_pure_attention=False,
|
123
|
-
kernel_init=self.kernel_init())(x)
|
123
|
+
kernel_init=self.kernel_init(1.0))(x)
|
124
124
|
|
125
125
|
# # Out blocks
|
126
126
|
for i in range(self.num_layers // 2):
|
127
127
|
x = jnp.concatenate([x, skips.pop()], axis=-1)
|
128
|
-
x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
|
128
|
+
x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(1.0),
|
129
129
|
dtype=self.dtype, precision=self.precision)(x)
|
130
130
|
x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
|
131
131
|
dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
|
132
132
|
use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
|
133
133
|
only_pure_attention=False,
|
134
|
-
kernel_init=self.kernel_init())(x)
|
134
|
+
kernel_init=self.kernel_init(1.0))(x)
|
135
135
|
|
136
136
|
# print(f'Shape of x after transformer blocks: {x.shape}')
|
137
137
|
x = self.norm()(x)
|
138
138
|
|
139
139
|
patch_dim = self.patch_size ** 2 * self.output_channels
|
140
|
-
x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init())(x)
|
140
|
+
x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init(1.0))(x)
|
141
141
|
x = x[:, 1 + num_text_tokens:, :]
|
142
142
|
x = unpatchify(x, channels=self.output_channels)
|
143
143
|
|
@@ -4,10 +4,10 @@ flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
|
|
4
4
|
flaxdiff/data/online_loader.py,sha256=LIK_O1C3yDPvvAEOWvsJrVeBopVqjg2IOMTbiSIvH6M,11025
|
5
5
|
flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
|
6
6
|
flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
|
7
|
-
flaxdiff/models/common.py,sha256=
|
7
|
+
flaxdiff/models/common.py,sha256=hWsSs2BP2J-JN1s4qLRr-h-KYkcVyl2hOp1Wsm_L-h8,10994
|
8
8
|
flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
|
9
9
|
flaxdiff/models/simple_unet.py,sha256=h1o9mQlLJy7Ec8Pz_O5miRbAyUaM5UNhSs-oXzpQvZo,10763
|
10
|
-
flaxdiff/models/simple_vit.py,sha256=
|
10
|
+
flaxdiff/models/simple_vit.py,sha256=Nnrlo5T9IUu3lu6y-SIWIgfISc07uOztBB4kyfBrQVY,7443
|
11
11
|
flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
|
12
12
|
flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
|
13
13
|
flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
|
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
|
|
34
34
|
flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
|
35
35
|
flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
|
36
36
|
flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
|
37
|
-
flaxdiff-0.1.
|
38
|
-
flaxdiff-0.1.
|
39
|
-
flaxdiff-0.1.
|
40
|
-
flaxdiff-0.1.
|
37
|
+
flaxdiff-0.1.30.dist-info/METADATA,sha256=lzEiqudjsqRLsDrI1icVnN3NM8hHrAqWloafwhxbhBE,22083
|
38
|
+
flaxdiff-0.1.30.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
|
39
|
+
flaxdiff-0.1.30.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
40
|
+
flaxdiff-0.1.30.dist-info/RECORD,,
|
File without changes
|
File without changes
|