flaxdiff 0.1.35.5__py3-none-any.whl → 0.1.35.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/models/attention.py +1 -0
- flaxdiff/models/simple_unet.py +11 -11
- flaxdiff/models/simple_vit.py +1 -1
- {flaxdiff-0.1.35.5.dist-info → flaxdiff-0.1.35.6.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.35.5.dist-info → flaxdiff-0.1.35.6.dist-info}/RECORD +7 -7
- {flaxdiff-0.1.35.5.dist-info → flaxdiff-0.1.35.6.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.35.5.dist-info → flaxdiff-0.1.35.6.dist-info}/top_level.txt +0 -0
flaxdiff/models/attention.py
CHANGED
flaxdiff/models/simple_unet.py
CHANGED
@@ -50,7 +50,7 @@ class Unet(nn.Module):
|
|
50
50
|
features=self.feature_depths[0],
|
51
51
|
kernel_size=(3, 3),
|
52
52
|
strides=(1, 1),
|
53
|
-
kernel_init=self.kernel_init(1.0),
|
53
|
+
kernel_init=self.kernel_init(scale=1.0),
|
54
54
|
dtype=self.dtype,
|
55
55
|
precision=self.precision
|
56
56
|
)(x)
|
@@ -65,7 +65,7 @@ class Unet(nn.Module):
|
|
65
65
|
down_conv_type,
|
66
66
|
name=f"down_{i}_residual_{j}",
|
67
67
|
features=dim_in,
|
68
|
-
kernel_init=self.kernel_init(1.0),
|
68
|
+
kernel_init=self.kernel_init(scale=1.0),
|
69
69
|
kernel_size=(3, 3),
|
70
70
|
strides=(1, 1),
|
71
71
|
activation=self.activation,
|
@@ -85,7 +85,7 @@ class Unet(nn.Module):
|
|
85
85
|
force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False),
|
86
86
|
norm_inputs=attention_config.get("norm_inputs", True),
|
87
87
|
explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
|
88
|
-
kernel_init=self.kernel_init(1.0),
|
88
|
+
kernel_init=self.kernel_init(scale=1.0),
|
89
89
|
name=f"down_{i}_attention_{j}")(x, textcontext)
|
90
90
|
# print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
|
91
91
|
downs.append(x)
|
@@ -108,7 +108,7 @@ class Unet(nn.Module):
|
|
108
108
|
middle_conv_type,
|
109
109
|
name=f"middle_res1_{j}",
|
110
110
|
features=middle_dim_out,
|
111
|
-
kernel_init=self.kernel_init(1.0),
|
111
|
+
kernel_init=self.kernel_init(scale=1.0),
|
112
112
|
kernel_size=(3, 3),
|
113
113
|
strides=(1, 1),
|
114
114
|
activation=self.activation,
|
@@ -129,13 +129,13 @@ class Unet(nn.Module):
|
|
129
129
|
force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
|
130
130
|
norm_inputs=middle_attention.get("norm_inputs", True),
|
131
131
|
explicitly_add_residual=middle_attention.get("explicitly_add_residual", True),
|
132
|
-
kernel_init=self.kernel_init(1.0),
|
132
|
+
kernel_init=self.kernel_init(scale=1.0),
|
133
133
|
name=f"middle_attention_{j}")(x, textcontext)
|
134
134
|
x = ResidualBlock(
|
135
135
|
middle_conv_type,
|
136
136
|
name=f"middle_res2_{j}",
|
137
137
|
features=middle_dim_out,
|
138
|
-
kernel_init=self.kernel_init(1.0),
|
138
|
+
kernel_init=self.kernel_init(scale=1.0),
|
139
139
|
kernel_size=(3, 3),
|
140
140
|
strides=(1, 1),
|
141
141
|
activation=self.activation,
|
@@ -157,7 +157,7 @@ class Unet(nn.Module):
|
|
157
157
|
up_conv_type,# if j == 0 else "separable",
|
158
158
|
name=f"up_{i}_residual_{j}",
|
159
159
|
features=dim_out,
|
160
|
-
kernel_init=self.kernel_init(1.0),
|
160
|
+
kernel_init=self.kernel_init(scale=1.0),
|
161
161
|
kernel_size=kernel_size,
|
162
162
|
strides=(1, 1),
|
163
163
|
activation=self.activation,
|
@@ -177,7 +177,7 @@ class Unet(nn.Module):
|
|
177
177
|
force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
|
178
178
|
norm_inputs=attention_config.get("norm_inputs", True),
|
179
179
|
explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
|
180
|
-
kernel_init=self.kernel_init(1.0),
|
180
|
+
kernel_init=self.kernel_init(scale=1.0),
|
181
181
|
name=f"up_{i}_attention_{j}")(x, textcontext)
|
182
182
|
# print("Upscaling ", i, x.shape)
|
183
183
|
if i != len(feature_depths) - 1:
|
@@ -196,7 +196,7 @@ class Unet(nn.Module):
|
|
196
196
|
features=self.feature_depths[0],
|
197
197
|
kernel_size=(3, 3),
|
198
198
|
strides=(1, 1),
|
199
|
-
kernel_init=self.kernel_init(1.0),
|
199
|
+
kernel_init=self.kernel_init(scale=1.0),
|
200
200
|
dtype=self.dtype,
|
201
201
|
precision=self.precision
|
202
202
|
)(x)
|
@@ -207,7 +207,7 @@ class Unet(nn.Module):
|
|
207
207
|
conv_type,
|
208
208
|
name="final_residual",
|
209
209
|
features=self.feature_depths[0],
|
210
|
-
kernel_init=self.kernel_init(1.0),
|
210
|
+
kernel_init=self.kernel_init(scale=1.0),
|
211
211
|
kernel_size=(3,3),
|
212
212
|
strides=(1, 1),
|
213
213
|
activation=self.activation,
|
@@ -226,7 +226,7 @@ class Unet(nn.Module):
|
|
226
226
|
kernel_size=(3, 3),
|
227
227
|
strides=(1, 1),
|
228
228
|
# activation=jax.nn.mish
|
229
|
-
kernel_init=self.kernel_init(0.0),
|
229
|
+
kernel_init=self.kernel_init(scale=0.0),
|
230
230
|
dtype=self.dtype,
|
231
231
|
precision=self.precision
|
232
232
|
)(x)
|
flaxdiff/models/simple_vit.py
CHANGED
@@ -70,7 +70,7 @@ class UViT(nn.Module):
|
|
70
70
|
kernel_init: Callable = partial(kernel_init, scale=1.0)
|
71
71
|
add_residualblock_output: bool = False
|
72
72
|
norm_inputs: bool = False
|
73
|
-
explicitly_add_residual: bool =
|
73
|
+
explicitly_add_residual: bool = True
|
74
74
|
|
75
75
|
def setup(self):
|
76
76
|
if self.norm_groups > 0:
|
@@ -3,11 +3,11 @@ flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
|
|
3
3
|
flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
|
4
4
|
flaxdiff/data/online_loader.py,sha256=DoHrMZCi5gMd9tmkCpZIUU9lGxvfYtuaz58943_lCRc,11315
|
5
5
|
flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
|
6
|
-
flaxdiff/models/attention.py,sha256=
|
6
|
+
flaxdiff/models/attention.py,sha256=JvrP7-09MV6IfRLRBhqjPmNUU-lkEMk9TOnJSBKcar8,13289
|
7
7
|
flaxdiff/models/common.py,sha256=hWsSs2BP2J-JN1s4qLRr-h-KYkcVyl2hOp1Wsm_L-h8,10994
|
8
8
|
flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
|
9
|
-
flaxdiff/models/simple_unet.py,sha256=
|
10
|
-
flaxdiff/models/simple_vit.py,sha256=
|
9
|
+
flaxdiff/models/simple_unet.py,sha256=L5m2j5580QP7pJ5VIme7U5xYA22PZiGP7qdvcKUnB38,11463
|
10
|
+
flaxdiff/models/simple_vit.py,sha256=UCDDr0XVnpf6tbJWKFtEt3_nAqMqOoakXf5amyVWZNo,7929
|
11
11
|
flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
|
12
12
|
flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
|
13
13
|
flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
|
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
|
|
34
34
|
flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
|
35
35
|
flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
|
36
36
|
flaxdiff/trainer/simple_trainer.py,sha256=cawm6fZNQoLLATMneAU2gQ9j7kefqHnBPHuaIj3i_a4,18237
|
37
|
-
flaxdiff-0.1.35.
|
38
|
-
flaxdiff-0.1.35.
|
39
|
-
flaxdiff-0.1.35.
|
40
|
-
flaxdiff-0.1.35.
|
37
|
+
flaxdiff-0.1.35.6.dist-info/METADATA,sha256=NVCk5V7Zc3iq-nrWTivzO17dQa1fIjYgjJb800ZrZhQ,22085
|
38
|
+
flaxdiff-0.1.35.6.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
39
|
+
flaxdiff-0.1.35.6.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
40
|
+
flaxdiff-0.1.35.6.dist-info/RECORD,,
|
File without changes
|
File without changes
|