flaxdiff 0.1.37.7__tar.gz → 0.1.38.1__tar.gz
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/PKG-INFO +1 -1
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/data/sources/tfds.py +7 -7
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/models/attention.py +22 -16
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/models/autoencoder/diffusers.py +4 -4
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/models/common.py +8 -18
- flaxdiff-0.1.38.1/flaxdiff/models/general.py +21 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/models/simple_unet.py +1 -12
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/models/simple_vit.py +8 -12
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/trainer/diffusion_trainer.py +2 -2
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/trainer/simple_trainer.py +24 -6
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff.egg-info/PKG-INFO +1 -1
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff.egg-info/SOURCES.txt +1 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/pyproject.toml +1 -1
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/README.md +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/__init__.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/data/__init__.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/data/dataset_map.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/data/datasets.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/data/online_loader.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/data/sources/gcs.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/metrics/inception.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/metrics/psnr.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/metrics/ssim.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/metrics/utils.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/models/__init__.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/models/autoencoder/__init__.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/models/favor_fastattn.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/predictors/__init__.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/samplers/__init__.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/samplers/common.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/samplers/ddim.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/samplers/ddpm.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/samplers/euler.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/samplers/heun_sampler.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/samplers/multistep_dpm.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/samplers/rk4_sampler.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/schedulers/__init__.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/schedulers/common.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/schedulers/continuous.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/schedulers/cosine.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/schedulers/discrete.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/schedulers/exp.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/schedulers/karras.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/schedulers/linear.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/schedulers/sqrt.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/trainer/__init__.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/trainer/video_diffusion_trainer.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff/utils.py +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff.egg-info/requires.txt +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/flaxdiff.egg-info/top_level.txt +0 -0
- {flaxdiff-0.1.37.7 → flaxdiff-0.1.38.1}/setup.cfg +0 -0
flaxdiff/data/sources/tfds.py

@@ -50,13 +50,12 @@ def tfds_augmenters(image_scale, method):
     else:
         interpolation = cv2.INTER_AREA
 
-
-        augmax.HorizontalFlip(0.5),
-        augmax.RandomContrast((-0.05, 0.05), 1.),
-        augmax.RandomBrightness((-0.2, 0.2), 1.)
-    )
+    from torchvision.transforms import v2
 
-    augments =
+    augments = v2.Compose([
+        v2.RandomHorizontalFlip(p=0.5),
+        v2.ColorJitter(brightness=0.2, contrast=0.05, saturation=0.2)
+    ])
 
     class augmenters(pygrain.MapTransform):
         def __init__(self, *args, **kwargs):
@@ -67,8 +66,9 @@ def tfds_augmenters(image_scale, method):
             image = element['image']
             image = cv2.resize(image, (image_scale, image_scale),
                                interpolation=interpolation)
-
+            image = augments(image)
             # image = (image - 127.5) / 127.5
+
             caption = labelizer(element)
             results = self.tokenize(caption)
             return {
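For context on the augmentation swap above (augmax to torchvision v2), here is a minimal sketch of how the new pipeline behaves on a single HWC uint8 frame. The explicit tensor conversion is an illustrative assumption, not taken from the package:

import numpy as np
import torch
from torchvision.transforms import v2

# Mirror of the new augmentation pipeline introduced in tfds.py
augments = v2.Compose([
    v2.RandomHorizontalFlip(p=0.5),
    v2.ColorJitter(brightness=0.2, contrast=0.05, saturation=0.2),
])

# torchvision v2 transforms operate on CHW tensors (or PIL images),
# so an HWC uint8 numpy frame is converted before and after.
image = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
out = augments(torch.from_numpy(image).permute(2, 0, 1))
out = out.permute(1, 2, 0).numpy()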
flaxdiff/models/attention.py

@@ -23,7 +23,7 @@ class EfficientAttention(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     use_bias: bool = True
-    kernel_init: Callable = kernel_init(1.0)
+    # kernel_init: Callable = kernel_init(1.0)
     force_fp32_for_softmax: bool = True
 
     def setup(self):
@@ -34,15 +34,21 @@ class EfficientAttention(nn.Module):
             self.heads * self.dim_head,
             precision=self.precision,
             use_bias=self.use_bias,
-            kernel_init=self.kernel_init,
+            # kernel_init=self.kernel_init,
             dtype=self.dtype
         )
         self.query = dense(name="to_q")
         self.key = dense(name="to_k")
         self.value = dense(name="to_v")
 
-        self.proj_attn = nn.DenseGeneral(
-
+        self.proj_attn = nn.DenseGeneral(
+            self.query_dim,
+            use_bias=False,
+            precision=self.precision,
+            # kernel_init=self.kernel_init,
+            dtype=self.dtype,
+            name="to_out_0"
+        )
         # self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16)
 
     def _reshape_tensor_to_head_dim(self, tensor):
@@ -115,7 +121,7 @@ class NormalAttention(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     use_bias: bool = True
-    kernel_init: Callable = kernel_init(1.0)
+    # kernel_init: Callable = kernel_init(1.0)
     force_fp32_for_softmax: bool = True
 
     def setup(self):
@@ -126,7 +132,7 @@ class NormalAttention(nn.Module):
             axis=-1,
             precision=self.precision,
             use_bias=self.use_bias,
-            kernel_init=self.kernel_init,
+            # kernel_init=self.kernel_init,
             dtype=self.dtype
         )
         self.query = dense(name="to_q")
@@ -140,7 +146,7 @@ class NormalAttention(nn.Module):
             use_bias=self.use_bias,
             dtype=self.dtype,
             name="to_out_0",
-            kernel_init=self.kernel_init
+            # kernel_init=self.kernel_init
             # kernel_init=jax.nn.initializers.xavier_uniform()
         )
 
@@ -236,7 +242,7 @@ class BasicTransformerBlock(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     use_bias: bool = True
-    kernel_init: Callable = kernel_init(1.0)
+    # kernel_init: Callable = kernel_init(1.0)
     use_flash_attention:bool = False
     use_cross_only:bool = False
     only_pure_attention:bool = False
@@ -256,7 +262,7 @@ class BasicTransformerBlock(nn.Module):
             precision=self.precision,
             use_bias=self.use_bias,
             dtype=self.dtype,
-            kernel_init=self.kernel_init,
+            # kernel_init=self.kernel_init,
             force_fp32_for_softmax=self.force_fp32_for_softmax
         )
         self.attention2 = attenBlock(
@@ -267,7 +273,7 @@ class BasicTransformerBlock(nn.Module):
             precision=self.precision,
             use_bias=self.use_bias,
             dtype=self.dtype,
-            kernel_init=self.kernel_init,
+            # kernel_init=self.kernel_init,
             force_fp32_for_softmax=self.force_fp32_for_softmax
         )
 
@@ -303,7 +309,7 @@ class TransformerBlock(nn.Module):
     use_self_and_cross:bool = True
     only_pure_attention:bool = False
     force_fp32_for_softmax: bool = True
-    kernel_init: Callable = kernel_init(1.0)
+    # kernel_init: Callable = kernel_init(1.0)
     norm_inputs: bool = True
     explicitly_add_residual: bool = True
 
@@ -317,12 +323,12 @@ class TransformerBlock(nn.Module):
         if self.use_linear_attention:
             projected_x = nn.Dense(features=inner_dim,
                                    use_bias=False, precision=self.precision,
-
+                                   # kernel_init=self.kernel_init,
                                    dtype=self.dtype, name=f'project_in')(x)
         else:
             projected_x = nn.Conv(
                 features=inner_dim, kernel_size=(1, 1),
-                kernel_init=self.kernel_init,
+                # kernel_init=self.kernel_init,
                 strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                 precision=self.precision, name=f'project_in_conv',
             )(x)
@@ -344,19 +350,19 @@ class TransformerBlock(nn.Module):
             use_cross_only=(not self.use_self_and_cross),
             only_pure_attention=self.only_pure_attention,
             force_fp32_for_softmax=self.force_fp32_for_softmax,
-            kernel_init=self.kernel_init
+            # kernel_init=self.kernel_init
         )(projected_x, context)
 
         if self.use_projection == True:
             if self.use_linear_attention:
                 projected_x = nn.Dense(features=C, precision=self.precision,
                                        dtype=self.dtype, use_bias=False,
-
+                                       # kernel_init=self.kernel_init,
                                        name=f'project_out')(projected_x)
             else:
                 projected_x = nn.Conv(
                     features=C, kernel_size=(1, 1),
-                    kernel_init=self.
+                    # kernel_init=self.kernel_init,
                     strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                     precision=self.precision, name=f'project_out_conv',
                 )(projected_x)
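The recurring change in this file is that the custom kernel_init is commented out everywhere, so nn.Dense, nn.DenseGeneral and nn.Conv fall back to Flax's default kernel initializer (lecun_normal). A small sketch of the resulting behaviour; the module and field names here are illustrative, not the file's actual definitions:

import jax
import jax.numpy as jnp
import flax.linen as nn

class OutProjection(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        # No kernel_init argument: Flax uses nn.initializers.lecun_normal()
        # for the kernel (bias is disabled here anyway).
        return nn.DenseGeneral(self.features, use_bias=False, name="to_out_0")(x)

params = OutProjection(features=64).init(jax.random.PRNGKey(0), jnp.ones((1, 16, 128)))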
flaxdiff/models/autoencoder/diffusers.py

@@ -14,15 +14,15 @@ class StableDiffusionVAE(AutoEncoder):
     def __init__(self, modelname = "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16):
 
         from diffusers.models.vae_flax import FlaxEncoder, FlaxDecoder
-        from diffusers import FlaxStableDiffusionPipeline
+        from diffusers import FlaxStableDiffusionPipeline, FlaxAutoencoderKL
 
-
+        vae, params = FlaxAutoencoderKL.from_pretrained(
             modelname,
-            revision=revision,
+            # revision=revision,
             dtype=dtype,
         )
 
-        vae = pipeline.vae
+        # vae = pipeline.vae
 
         enc = FlaxEncoder(
             in_channels=vae.config.in_channels,
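For reference, the new FlaxAutoencoderKL loading path can be exercised roughly as below; the subfolder argument is an assumption for pulling only the VAE out of a full Stable Diffusion checkpoint and is not part of this diff:

import jax.numpy as jnp
from diffusers import FlaxAutoencoderKL

# from_pretrained on diffusers' Flax model classes returns (module, params).
vae, params = FlaxAutoencoderKL.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    subfolder="vae",        # assumption: load the VAE weights only
    dtype=jnp.bfloat16,
)
print(vae.config.in_channels, vae.config.latent_channels)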
flaxdiff/models/common.py

@@ -108,13 +108,16 @@ class FourierEmbedding(nn.Module):
 class TimeProjection(nn.Module):
     features:int
     activation:Callable=jax.nn.gelu
-    kernel_init:Callable=kernel_init(1.0)
 
     @nn.compact
     def __call__(self, x):
-        x = nn.DenseGeneral(
+        x = nn.DenseGeneral(
+            self.features,
+        )(x)
         x = self.activation(x)
-        x = nn.DenseGeneral(
+        x = nn.DenseGeneral(
+            self.features,
+        )(x)
         x = self.activation(x)
         return x
 
@@ -123,7 +126,6 @@ class SeparableConv(nn.Module):
     kernel_size:tuple=(3, 3)
     strides:tuple=(1, 1)
     use_bias:bool=False
-    kernel_init:Callable=kernel_init(1.0)
     padding:str="SAME"
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
@@ -133,7 +135,7 @@ class SeparableConv(nn.Module):
         in_features = x.shape[-1]
         depthwise = nn.Conv(
             features=in_features, kernel_size=self.kernel_size,
-            strides=self.strides,
+            strides=self.strides,
             feature_group_count=in_features, use_bias=self.use_bias,
             padding=self.padding,
             dtype=self.dtype,
@@ -141,7 +143,7 @@ class SeparableConv(nn.Module):
         )(x)
         pointwise = nn.Conv(
             features=self.features, kernel_size=(1, 1),
-            strides=(1, 1),
+            strides=(1, 1),
             use_bias=self.use_bias,
             dtype=self.dtype,
             precision=self.precision
@@ -153,7 +155,6 @@ class ConvLayer(nn.Module):
     features:int
     kernel_size:tuple=(3, 3)
     strides:tuple=(1, 1)
-    kernel_init:Callable=kernel_init(1.0)
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
 
@@ -164,7 +165,6 @@ class ConvLayer(nn.Module):
             features=self.features,
             kernel_size=self.kernel_size,
             strides=self.strides,
-            kernel_init=self.kernel_init,
             dtype=self.dtype,
             precision=self.precision
         )
@@ -183,7 +183,6 @@ class ConvLayer(nn.Module):
             features=self.features,
             kernel_size=self.kernel_size,
             strides=self.strides,
-            kernel_init=self.kernel_init,
             dtype=self.dtype,
             precision=self.precision
         )
@@ -192,7 +191,6 @@ class ConvLayer(nn.Module):
             features=self.features,
             kernel_size=self.kernel_size,
             strides=self.strides,
-            kernel_init=self.kernel_init,
             dtype=self.dtype,
             precision=self.precision
         )
@@ -206,7 +204,6 @@ class Upsample(nn.Module):
     activation:Callable=jax.nn.swish
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
-    kernel_init:Callable=kernel_init(1.0)
 
     @nn.compact
     def __call__(self, x, residual=None):
@@ -221,7 +218,6 @@ class Upsample(nn.Module):
             strides=(1, 1),
             dtype=self.dtype,
             precision=self.precision,
-            kernel_init=self.kernel_init
         )(out)
         if residual is not None:
             out = jnp.concatenate([out, residual], axis=-1)
@@ -233,7 +229,6 @@ class Downsample(nn.Module):
     activation:Callable=jax.nn.swish
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
-    kernel_init:Callable=kernel_init(1.0)
 
     @nn.compact
     def __call__(self, x, residual=None):
@@ -244,7 +239,6 @@ class Downsample(nn.Module):
             strides=(2, 2),
             dtype=self.dtype,
             precision=self.precision,
-            kernel_init=self.kernel_init
         )(x)
         if residual is not None:
             if residual.shape[1] > out.shape[1]:
@@ -269,7 +263,6 @@ class ResidualBlock(nn.Module):
     direction:str=None
     res:int=2
     norm_groups:int=8
-    kernel_init:Callable=kernel_init(1.0)
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     named_norms:bool=False
@@ -296,7 +289,6 @@ class ResidualBlock(nn.Module):
             features=self.features,
             kernel_size=self.kernel_size,
             strides=self.strides,
-            kernel_init=self.kernel_init,
             name="conv1",
             dtype=self.dtype,
             precision=self.precision
@@ -321,7 +313,6 @@ class ResidualBlock(nn.Module):
             features=self.features,
             kernel_size=self.kernel_size,
             strides=self.strides,
-            kernel_init=self.kernel_init,
             name="conv2",
             dtype=self.dtype,
             precision=self.precision
@@ -333,7 +324,6 @@ class ResidualBlock(nn.Module):
             features=self.features,
             kernel_size=(1, 1),
             strides=1,
-            kernel_init=self.kernel_init,
             name="residual_conv",
             dtype=self.dtype,
             precision=self.precision
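After this change TimeProjection is simply two default-initialized DenseGeneral layers with the activation applied after each. A self-contained sketch equivalent to the new code, applied to a batch of hypothetical time embeddings:

from typing import Callable
import jax
import jax.numpy as jnp
import flax.linen as nn

class TimeProjection(nn.Module):
    features: int
    activation: Callable = jax.nn.gelu

    @nn.compact
    def __call__(self, x):
        x = nn.DenseGeneral(self.features)(x)   # default (lecun_normal) kernel init
        x = self.activation(x)
        x = nn.DenseGeneral(self.features)(x)
        x = self.activation(x)
        return x

temb = jnp.ones((4, 256))                        # e.g. Fourier time embeddings
out, _ = TimeProjection(features=512).init_with_output(jax.random.PRNGKey(0), temb)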
flaxdiff/models/general.py (new file)

@@ -0,0 +1,21 @@
+from flax import linen as nn
+import jax
+import jax.numpy as jnp
+
+class BCHWModelWrapper(nn.Module):
+    model: nn.Module
+
+    @nn.compact
+    def __call__(self, x, temb, textcontext):
+        # Reshape the input to BCHW format from BHWC
+        x = jnp.transpose(x, (0, 3, 1, 2))
+        # Pass the input through the UNet model
+        out = self.model(
+            sample=x,
+            timesteps=temb,
+            encoder_hidden_states=textcontext,
+        )
+        # Reshape the output back to BHWC format
+        out = jnp.transpose(out.sample, (0, 2, 3, 1))
+        return out
+
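The new BCHWModelWrapper adapts channel-first models (such as diffusers' Flax UNet, which takes sample, timesteps and encoder_hidden_states) to flaxdiff's channel-last inputs. A hypothetical usage sketch; the checkpoint and shapes are assumptions, and in practice the pretrained UNet params would be merged into the wrapper's variables rather than re-initialized as shown:

import jax
import jax.numpy as jnp
from diffusers import FlaxUNet2DConditionModel
from flaxdiff.models.general import BCHWModelWrapper

unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", dtype=jnp.bfloat16
)
wrapped = BCHWModelWrapper(model=unet)

x = jnp.ones((1, 64, 64, 4))      # BHWC latents
temb = jnp.ones((1,))             # diffusion timestep(s)
ctx = jnp.ones((1, 77, 768))      # text-encoder hidden states
variables = wrapped.init(jax.random.PRNGKey(0), x, temb, ctx)   # fresh params, shape check only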
flaxdiff/models/simple_unet.py

@@ -20,7 +20,6 @@ class Unet(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     named_norms: bool = False # This is for backward compatibility reasons; older checkpoints have named norms
-    kernel_init: Callable = partial(kernel_init, dtype=jnp.float32)
 
     def setup(self):
         if self.norm_groups > 0:
@@ -50,7 +49,6 @@ class Unet(nn.Module):
             features=self.feature_depths[0],
             kernel_size=(3, 3),
             strides=(1, 1),
-            kernel_init=self.kernel_init(scale=1.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
@@ -65,7 +63,6 @@ class Unet(nn.Module):
                 down_conv_type,
                 name=f"down_{i}_residual_{j}",
                 features=dim_in,
-                kernel_init=self.kernel_init(scale=1.0),
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 activation=self.activation,
@@ -85,7 +82,6 @@ class Unet(nn.Module):
                     force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False),
                     norm_inputs=attention_config.get("norm_inputs", True),
                     explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
-                    kernel_init=self.kernel_init(scale=1.0),
                     name=f"down_{i}_attention_{j}")(x, textcontext)
             # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
             downs.append(x)
@@ -108,7 +104,6 @@ class Unet(nn.Module):
                 middle_conv_type,
                 name=f"middle_res1_{j}",
                 features=middle_dim_out,
-                kernel_init=self.kernel_init(scale=1.0),
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 activation=self.activation,
@@ -129,13 +124,11 @@ class Unet(nn.Module):
                     force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
                     norm_inputs=middle_attention.get("norm_inputs", True),
                     explicitly_add_residual=middle_attention.get("explicitly_add_residual", True),
-                    kernel_init=self.kernel_init(scale=1.0),
                     name=f"middle_attention_{j}")(x, textcontext)
             x = ResidualBlock(
                 middle_conv_type,
                 name=f"middle_res2_{j}",
                 features=middle_dim_out,
-                kernel_init=self.kernel_init(scale=1.0),
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 activation=self.activation,
@@ -157,7 +150,6 @@ class Unet(nn.Module):
                 up_conv_type,# if j == 0 else "separable",
                 name=f"up_{i}_residual_{j}",
                 features=dim_out,
-                kernel_init=self.kernel_init(scale=1.0),
                 kernel_size=kernel_size,
                 strides=(1, 1),
                 activation=self.activation,
@@ -177,7 +169,6 @@ class Unet(nn.Module):
                     force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
                     norm_inputs=attention_config.get("norm_inputs", True),
                     explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
-                    kernel_init=self.kernel_init(scale=1.0),
                     name=f"up_{i}_attention_{j}")(x, textcontext)
             # print("Upscaling ", i, x.shape)
             if i != len(feature_depths) - 1:
@@ -196,7 +187,6 @@ class Unet(nn.Module):
             features=self.feature_depths[0],
             kernel_size=(3, 3),
             strides=(1, 1),
-            kernel_init=self.kernel_init(scale=1.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
@@ -207,7 +197,6 @@ class Unet(nn.Module):
             conv_type,
             name="final_residual",
             features=self.feature_depths[0],
-            kernel_init=self.kernel_init(scale=1.0),
             kernel_size=(3,3),
             strides=(1, 1),
             activation=self.activation,
@@ -226,7 +215,7 @@ class Unet(nn.Module):
             kernel_size=(3, 3),
             strides=(1, 1),
             # activation=jax.nn.mish
-            kernel_init=self.kernel_init(scale=0.0),
+            # kernel_init=self.kernel_init(scale=0.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
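Beyond dropping the now-unused kernel_init field, the one behavioural change in this file is on the final 3x3 conv, where kernel_init(scale=0.0) (a zero-scale init so the UNet's output head starts near zero) is commented out in favour of Flax's default initializer. If zero initialization of the output head is still desired, it can be passed explicitly; a sketch with an assumed channel count:

import flax.linen as nn

final_conv = nn.Conv(
    features=3,                          # assumption: RGB output channels
    kernel_size=(3, 3), strides=(1, 1),
    kernel_init=nn.initializers.zeros,   # output head starts at zero
)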
flaxdiff/models/simple_vit.py

@@ -23,7 +23,6 @@ class PatchEmbedding(nn.Module):
     embedding_dim: int
     dtype: Any = jnp.float32
     precision: Any = jax.lax.Precision.HIGH
-    kernel_init: Callable = partial(kernel_init, 1.0)
 
     @nn.compact
     def __call__(self, x):
@@ -34,7 +33,6 @@ class PatchEmbedding(nn.Module):
                     kernel_size=(self.patch_size, self.patch_size),
                     strides=(self.patch_size, self.patch_size),
                     dtype=self.dtype,
-                    kernel_init=self.kernel_init(),
                     precision=self.precision)(x)
         x = jnp.reshape(x, (batch, -1, self.embedding_dim))
         return x
@@ -67,7 +65,7 @@ class UViT(nn.Module):
     norm_groups:int=8
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
-    kernel_init: Callable = partial(kernel_init, scale=1.0)
+    # kernel_init: Callable = partial(kernel_init, scale=1.0)
     add_residualblock_output: bool = False
     norm_inputs: bool = False
     explicitly_add_residual: bool = True
@@ -88,10 +86,10 @@ class UViT(nn.Module):
 
         # Patch embedding
         x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features,
-                           dtype=self.dtype, precision=self.precision
+                           dtype=self.dtype, precision=self.precision)(x)
         num_patches = x.shape[1]
 
-        context_emb = nn.DenseGeneral(features=self.emb_features,
+        context_emb = nn.DenseGeneral(features=self.emb_features,
                                       dtype=self.dtype, precision=self.precision)(textcontext)
         num_text_tokens = textcontext.shape[1]
 
@@ -116,7 +114,7 @@ class UViT(nn.Module):
                                  only_pure_attention=False,
                                  norm_inputs=self.norm_inputs,
                                  explicitly_add_residual=self.explicitly_add_residual,
-
+                                 )(x)
             skips.append(x)
 
         # Middle block
@@ -126,12 +124,12 @@ class UViT(nn.Module):
                              only_pure_attention=False,
                              norm_inputs=self.norm_inputs,
                              explicitly_add_residual=self.explicitly_add_residual,
-
+                             )(x)
 
         # # Out blocks
         for i in range(self.num_layers // 2):
             x = jnp.concatenate([x, skips.pop()], axis=-1)
-            x = nn.DenseGeneral(features=self.emb_features,
+            x = nn.DenseGeneral(features=self.emb_features,
                                 dtype=self.dtype, precision=self.precision)(x)
             x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                                  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
@@ -139,13 +137,13 @@ class UViT(nn.Module):
                                  only_pure_attention=False,
                                  norm_inputs=self.norm_inputs,
                                  explicitly_add_residual=self.explicitly_add_residual,
-
+                                 )(x)
 
         # print(f'Shape of x after transformer blocks: {x.shape}')
         x = self.norm()(x)
 
         patch_dim = self.patch_size ** 2 * self.output_channels
-        x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision
+        x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision)(x)
         x = x[:, 1 + num_text_tokens:, :]
         x = unpatchify(x, channels=self.output_channels)
 
@@ -159,7 +157,6 @@ class UViT(nn.Module):
             kernel_size=(3, 3),
             strides=(1, 1),
             # activation=jax.nn.mish
-            kernel_init=self.kernel_init(scale=0.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
@@ -173,7 +170,6 @@ class UViT(nn.Module):
             kernel_size=(3, 3),
            strides=(1, 1),
             # activation=jax.nn.mish
-            kernel_init=self.kernel_init(scale=0.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
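patch_dim above is patch_size squared times output_channels because each output token is later folded back into a square pixel patch by unpatchify. A generic sketch of that operation for a square patch grid (not the package's exact implementation):

import jax.numpy as jnp

def unpatchify_sketch(x, channels=3):
    # x: (batch, num_patches, patch_size * patch_size * channels), square grid assumed
    b, n, d = x.shape
    p = int((d // channels) ** 0.5)
    h = w = int(n ** 0.5)
    x = x.reshape(b, h, w, p, p, channels)
    x = jnp.transpose(x, (0, 1, 3, 2, 4, 5))    # interleave patch rows and columns
    return x.reshape(b, h * p, w * p, channels)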
flaxdiff/trainer/diffusion_trainer.py

@@ -231,11 +231,11 @@ class DiffusionTrainer(SimpleTrainer):
                 ),
             )
 
-
+            new_state = new_state.apply_ema(self.ema_decay)
 
             if distributed_training:
                 loss = jax.lax.pmean(loss, "data")
-            return
+            return new_state, loss, rng_state
 
         if distributed_training:
             train_step = shard_map(
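apply_ema is defined on flaxdiff's train state elsewhere in the package; the fix here is to actually apply it every step and to return the updated state, loss and RNG state. A minimal sketch of the usual EMA update such a method performs (field names are assumptions, not the package's exact definition):

from typing import Any
import jax
from flax.training import train_state

class TrainState(train_state.TrainState):
    ema_params: Any = None

    def apply_ema(self, decay: float = 0.999):
        # ema <- decay * ema + (1 - decay) * params, leaf by leaf
        new_ema = jax.tree_util.tree_map(
            lambda e, p: decay * e + (1.0 - decay) * p,
            self.ema_params, self.params,
        )
        return self.replace(ema_params=new_ema)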
flaxdiff/trainer/simple_trainer.py

@@ -159,7 +159,7 @@ class SimpleTrainer:
         self.best_loss = 1e9
 
     def get_input_ones(self):
-        return {k: jnp.ones((1, *v)) for k, v in self.input_shapes.items()}
+        return {k: jnp.ones((1, *v), dtype=self.model.dtype) for k, v in self.input_shapes.items()}
 
     def generate_states(
         self,
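get_input_ones now builds the dummy initialization inputs in the model's dtype rather than the default float32, which keeps tracing consistent for mixed-precision models. A tiny illustration with an assumed model definition:

from typing import Any
import jax
import jax.numpy as jnp
import flax.linen as nn

class Tiny(nn.Module):
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x):
        # Computation runs in self.dtype; parameters stay in param_dtype (float32).
        return nn.Dense(4, dtype=self.dtype)(x)

model = Tiny()
ones = jnp.ones((1, 8), dtype=model.dtype)     # mirrors the new get_input_ones()
params = model.init(jax.random.PRNGKey(0), ones)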
@@ -437,12 +437,30 @@ class SimpleTrainer:
                     # If the loss is too low, we can assume the model has diverged
                     print(colored(f"Loss too low at step {current_step} => {loss}", 'red'))
                     # Reset the model to the old state
-                    if self.best_state is not None:
-
-
-
+                    # if self.best_state is not None:
+                    #     print(colored(f"Resetting model to best state", 'red'))
+                    #     train_state = self.best_state
+                    #     loss = self.best_loss
+                    # else:
+                    #     exit(1)
+
+                    # Check if there are any NaN/inf values in the train_state.params
+                    params = train_state.params
+                    if isinstance(params, dict):
+                        for key, value in params.items():
+                            if isinstance(value, jnp.ndarray):
+                                if jnp.isnan(value).any() or jnp.isinf(value).any():
+                                    print(colored(f"NaN/inf values found in params at step {current_step}", 'red'))
+                                    # Reset the model to the old state
+                                    # train_state = self.best_state
+                                    # loss = self.best_loss
+                                    # break
+                                else:
+                                    print(colored(f"Params are fine at step {current_step}", 'green'))
                     else:
-
+                        print(colored(f"Params are not a dict at step {current_step}", 'red'))
+
+                    exit(1)
 
                 epoch_loss += loss
                 current_step += 1
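Flax parameters are typically a nested pytree of arrays rather than a flat dict, so the params.items() loop above usually only sees per-module sub-dicts at the top level and the isinstance(value, jnp.ndarray) branch rarely fires. A sketch of a check over every leaf, not part of this release:

import jax
import jax.numpy as jnp

def params_are_finite(params) -> bool:
    # Reduce over every leaf array of the (possibly nested) parameter pytree.
    leaf_ok = jax.tree_util.tree_map(lambda p: jnp.all(jnp.isfinite(p)), params)
    return bool(jax.tree_util.tree_reduce(jnp.logical_and, leaf_ok, jnp.bool_(True)))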
flaxdiff.egg-info/SOURCES.txt

@@ -21,6 +21,7 @@ flaxdiff/models/__init__.py
 flaxdiff/models/attention.py
 flaxdiff/models/common.py
 flaxdiff/models/favor_fastattn.py
+flaxdiff/models/general.py
 flaxdiff/models/simple_unet.py
 flaxdiff/models/simple_vit.py
 flaxdiff/models/autoencoder/__init__.py