flaxdiff 0.1.28__py3-none-any.whl → 0.1.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/models/simple_vit.py +6 -13
- {flaxdiff-0.1.28.dist-info → flaxdiff-0.1.29.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.28.dist-info → flaxdiff-0.1.29.dist-info}/RECORD +5 -5
- {flaxdiff-0.1.28.dist-info → flaxdiff-0.1.29.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.28.dist-info → flaxdiff-0.1.29.dist-info}/top_level.txt +0 -0
flaxdiff/models/simple_vit.py
CHANGED
@@ -136,32 +136,25 @@ class UViT(nn.Module):
|
|
136
136
|
# print(f'Shape of x after transformer blocks: {x.shape}')
|
137
137
|
x = self.norm()(x)
|
138
138
|
|
139
|
-
# print(f'Shape of x after norm: {x.shape}')
|
140
|
-
|
141
139
|
patch_dim = self.patch_size ** 2 * self.output_channels
|
142
140
|
x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init())(x)
|
143
|
-
# print(f'Shape of x after patch dense layer: {x.shape}, patch_dim: {patch_dim}')
|
144
141
|
x = x[:, 1 + num_text_tokens:, :]
|
145
142
|
x = unpatchify(x, channels=self.output_channels)
|
146
|
-
# print(f'Shape of x after final dense layer: {x.shape}')
|
147
143
|
|
148
144
|
if self.add_residualblock_output:
|
149
145
|
# Concatenate the original image
|
150
146
|
x = jnp.concatenate([original_img, x], axis=-1)
|
151
147
|
|
152
|
-
x =
|
148
|
+
x = ConvLayer(
|
153
149
|
"conv",
|
154
|
-
name="final_residual",
|
155
150
|
features=64,
|
156
|
-
|
157
|
-
kernel_size=(3,3),
|
151
|
+
kernel_size=(3, 3),
|
158
152
|
strides=(1, 1),
|
159
|
-
activation=
|
160
|
-
|
153
|
+
# activation=jax.nn.mish
|
154
|
+
kernel_init=self.kernel_init(0.0),
|
161
155
|
dtype=self.dtype,
|
162
|
-
precision=self.precision
|
163
|
-
|
164
|
-
)(x, temb)
|
156
|
+
precision=self.precision
|
157
|
+
)(x)
|
165
158
|
|
166
159
|
x = self.norm()(x)
|
167
160
|
x = self.activation(x)
|
@@ -7,7 +7,7 @@ flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,
|
|
7
7
|
flaxdiff/models/common.py,sha256=fw_gP7PZayO6RVe6xSf-7FtVq-S0pp5U6NgHg4PlKO8,10990
|
8
8
|
flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
|
9
9
|
flaxdiff/models/simple_unet.py,sha256=h1o9mQlLJy7Ec8Pz_O5miRbAyUaM5UNhSs-oXzpQvZo,10763
|
10
|
-
flaxdiff/models/simple_vit.py,sha256=
|
10
|
+
flaxdiff/models/simple_vit.py,sha256=atjeXc22w8WYub_6d0JAFFgvQ4TP1wt4N1ubIzZlQZ0,7436
|
11
11
|
flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
|
12
12
|
flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
|
13
13
|
flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
|
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
|
|
34
34
|
flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
|
35
35
|
flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
|
36
36
|
flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
|
37
|
-
flaxdiff-0.1.
|
38
|
-
flaxdiff-0.1.
|
39
|
-
flaxdiff-0.1.
|
40
|
-
flaxdiff-0.1.
|
37
|
+
flaxdiff-0.1.29.dist-info/METADATA,sha256=PcevgEjt61-62ccMC_CI4EvHYUX-tdrpEBptKXkTudA,22083
|
38
|
+
flaxdiff-0.1.29.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
|
39
|
+
flaxdiff-0.1.29.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
40
|
+
flaxdiff-0.1.29.dist-info/RECORD,,
|
File without changes
|
File without changes
|