flaxdiff 0.1.35.3__py3-none-any.whl → 0.1.35.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/online_loader.py +2 -2
- flaxdiff/models/attention.py +12 -6
- flaxdiff/models/simple_unet.py +6 -0
- flaxdiff/models/simple_vit.py +10 -2
- flaxdiff/trainer/simple_trainer.py +1 -0
- {flaxdiff-0.1.35.3.dist-info → flaxdiff-0.1.35.5.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.35.3.dist-info → flaxdiff-0.1.35.5.dist-info}/RECORD +9 -9
- {flaxdiff-0.1.35.3.dist-info → flaxdiff-0.1.35.5.dist-info}/WHEEL +1 -1
- {flaxdiff-0.1.35.3.dist-info → flaxdiff-0.1.35.5.dist-info}/top_level.txt +0 -0
flaxdiff/data/online_loader.py
CHANGED
@@ -111,8 +111,8 @@ def map_sample(
             "original_width": original_width,
         })
     except Exception as e:
-        print(f"Error maping sample {url}", e)
-        traceback.print_exc()
+        # print(f"Error maping sample {url}", e)
+        # traceback.print_exc()
         # error_queue.put_nowait({
         #     "url": url,
         #     "caption": caption,
flaxdiff/models/attention.py
CHANGED
@@ -303,27 +303,30 @@ class TransformerBlock(nn.Module):
     only_pure_attention:bool = False
     force_fp32_for_softmax: bool = True
     kernel_init: Callable = kernel_init(1.0)
+    norm_inputs: bool = True
+    explicitly_add_residual: bool = True

     @nn.compact
     def __call__(self, x, context=None):
         inner_dim = self.heads * self.dim_head
         C = x.shape[-1]
-
+        if self.norm_inputs:
+            x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
         if self.use_projection == True:
             if self.use_linear_attention:
                 projected_x = nn.Dense(features=inner_dim,
                                        use_bias=False, precision=self.precision,
                                        kernel_init=self.kernel_init,
-                                       dtype=self.dtype, name=f'project_in')(
+                                       dtype=self.dtype, name=f'project_in')(x)
             else:
                 projected_x = nn.Conv(
                     features=inner_dim, kernel_size=(1, 1),
                     kernel_init=self.kernel_init,
                     strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                     precision=self.precision, name=f'project_in_conv',
-                )(
+                )(x)
         else:
-            projected_x =
+            projected_x = x
             inner_dim = C

         context = projected_x if context is None else context
@@ -356,6 +359,9 @@ class TransformerBlock(nn.Module):
                 strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                 precision=self.precision, name=f'project_out_conv',
             )(projected_x)
-
-
+
+        if self.only_pure_attention or self.explicitly_add_residual:
+            projected_x = x + projected_x
+
+        out = projected_x
         return out
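The two new TransformerBlock flags control pre-normalization of the block input and an explicit output residual. Below is a minimal, self-contained Flax sketch of that control flow only, not the flaxdiff implementation: TinyBlock, the nn.SelfAttention stub, and the shapes are illustrative assumptions, while the RMSNorm/residual logic mirrors the hunk above.

    # Sketch of the norm_inputs / explicitly_add_residual behaviour added above.
    # Attention is stubbed with nn.SelfAttention; TinyBlock is illustrative only.
    import flax.linen as nn
    import jax
    import jax.numpy as jnp

    class TinyBlock(nn.Module):
        heads: int = 8
        only_pure_attention: bool = False
        norm_inputs: bool = True              # new flag: RMSNorm the input first
        explicitly_add_residual: bool = True  # new flag: add the (normed) input back

        @nn.compact
        def __call__(self, x):
            if self.norm_inputs:
                x = nn.RMSNorm(epsilon=1e-5)(x)
            projected_x = nn.SelfAttention(num_heads=self.heads)(x)
            if self.only_pure_attention or self.explicitly_add_residual:
                projected_x = x + projected_x  # explicit residual connection
            return projected_x

    x = jnp.ones((1, 64, 128))                 # (batch, tokens, channels)
    block = TinyBlock()
    params = block.init(jax.random.PRNGKey(0), x)
    y = block.apply(params, x)                 # same shape as x

Note that, as in the hunk, the residual uses the post-norm input, and only_pure_attention also triggers the explicit residual.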
flaxdiff/models/simple_unet.py
CHANGED
@@ -83,6 +83,8 @@ class Unet(nn.Module):
                         precision=attention_config.get("precision", self.precision),
                         only_pure_attention=attention_config.get("only_pure_attention", True),
                         force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False),
+                        norm_inputs=attention_config.get("norm_inputs", True),
+                        explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
                         kernel_init=self.kernel_init(1.0),
                         name=f"down_{i}_attention_{j}")(x, textcontext)
                 # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
@@ -125,6 +127,8 @@ class Unet(nn.Module):
                     precision=middle_attention.get("precision", self.precision),
                     only_pure_attention=middle_attention.get("only_pure_attention", True),
                     force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
+                    norm_inputs=middle_attention.get("norm_inputs", True),
+                    explicitly_add_residual=middle_attention.get("explicitly_add_residual", True),
                     kernel_init=self.kernel_init(1.0),
                     name=f"middle_attention_{j}")(x, textcontext)
             x = ResidualBlock(
@@ -171,6 +175,8 @@ class Unet(nn.Module):
                         precision=attention_config.get("precision", self.precision),
                         only_pure_attention=attention_config.get("only_pure_attention", True),
                         force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
+                        norm_inputs=attention_config.get("norm_inputs", True),
+                        explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
                         kernel_init=self.kernel_init(1.0),
                         name=f"up_{i}_attention_{j}")(x, textcontext)
                 # print("Upscaling ", i, x.shape)
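The Unet forwards the new flags from each block's attention_config dict, defaulting to True when a key is absent. A hedged example of such a dict follows; only these key names appear in the diff above, and the values shown are illustrative rather than required settings.

    # Illustrative attention_config entries read by the Unet in 0.1.35.5.
    # Omitting the two new keys is equivalent to setting them to True.
    attention_config = {
        "only_pure_attention": True,
        "force_fp32_for_softmax": False,
        "norm_inputs": True,              # new in 0.1.35.5
        "explicitly_add_residual": True,  # new in 0.1.35.5
    }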
flaxdiff/models/simple_vit.py
CHANGED
@@ -69,6 +69,8 @@ class UViT(nn.Module):
     precision: PrecisionLike = None
     kernel_init: Callable = partial(kernel_init, scale=1.0)
     add_residualblock_output: bool = False
+    norm_inputs: bool = False
+    explicitly_add_residual: bool = False

     def setup(self):
         if self.norm_groups > 0:
@@ -110,16 +112,20 @@ class UViT(nn.Module):
         for i in range(self.num_layers // 2):
             x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                                  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                                 use_flash_attention=self.use_flash_attention, use_self_and_cross=
+                                 use_flash_attention=self.use_flash_attention, use_self_and_cross=False, force_fp32_for_softmax=self.force_fp32_for_softmax,
                                  only_pure_attention=False,
+                                 norm_inputs=self.norm_inputs,
+                                 explicitly_add_residual=self.explicitly_add_residual,
                                  kernel_init=self.kernel_init())(x)
             skips.append(x)

         # Middle block
         x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                              dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                             use_flash_attention=self.use_flash_attention, use_self_and_cross=
+                             use_flash_attention=self.use_flash_attention, use_self_and_cross=False, force_fp32_for_softmax=self.force_fp32_for_softmax,
                              only_pure_attention=False,
+                             norm_inputs=self.norm_inputs,
+                             explicitly_add_residual=self.explicitly_add_residual,
                              kernel_init=self.kernel_init())(x)

         # # Out blocks
@@ -131,6 +137,8 @@ class UViT(nn.Module):
                                  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
                                  use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
                                  only_pure_attention=False,
+                                 norm_inputs=self.norm_inputs,
+                                 explicitly_add_residual=self.explicitly_add_residual,
                                  kernel_init=self.kernel_init())(x)

             # print(f'Shape of x after transformer blocks: {x.shape}')
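UViT gains the same two fields, but defaulting to False, and now forwards them into its down, middle, and out TransformerBlocks (the down and middle blocks also now pass use_self_and_cross=False and force_fp32_for_softmax explicitly). A small hedged check of those defaults, assuming flaxdiff 0.1.35.5 is installed and importable, is sketched below; it relies only on flax.linen modules being dataclasses.

    # Hedged sketch: inspect the two new UViT fields and their defaults from this hunk.
    import dataclasses
    from flaxdiff.models.simple_vit import UViT

    new_fields = {
        f.name: f.default
        for f in dataclasses.fields(UViT)
        if f.name in ("norm_inputs", "explicitly_add_residual")
    }
    print(new_fields)  # expected: {'norm_inputs': False, 'explicitly_add_residual': False}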
flaxdiff/trainer/simple_trainer.py
CHANGED
@@ -437,5 +437,6 @@ class SimpleTrainer:
                     "train/epoch": current_epoch,
                 }, step=current_step)
             print(colored(f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}", 'green'))
+        print("Training done")
         self.save(epochs)
         return self.state
{flaxdiff-0.1.35.3.dist-info → flaxdiff-0.1.35.5.dist-info}/RECORD
CHANGED
@@ -1,13 +1,13 @@
 flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
 flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
-flaxdiff/data/online_loader.py,sha256=
+flaxdiff/data/online_loader.py,sha256=DoHrMZCi5gMd9tmkCpZIUU9lGxvfYtuaz58943_lCRc,11315
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
-flaxdiff/models/attention.py,sha256=
+flaxdiff/models/attention.py,sha256=Oz0-F0jllo3cIqQQJwYbYMrkhVzZWf5bq1UC6_RU1r8,13234
 flaxdiff/models/common.py,sha256=hWsSs2BP2J-JN1s4qLRr-h-KYkcVyl2hOp1Wsm_L-h8,10994
 flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
-flaxdiff/models/simple_unet.py,sha256=
-flaxdiff/models/simple_vit.py,sha256=
+flaxdiff/models/simple_unet.py,sha256=EmTmf3wLIVULGpL1cSmAPhvg7RW4U8Ff3E8Qrtt8RLY,11397
+flaxdiff/models/simple_vit.py,sha256=YQiK5AHMjNxH5DAF6Oy4kjRPqddrvUP4i-CuavASPPI,7930
 flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
 flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
 flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
@@ -33,8 +33,8 @@ flaxdiff/schedulers/sqrt.py,sha256=1F84ZgQPuoNMhe6yxGTR2G0h7dPOZtm4UDQOakbSsEU,4
 flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,128
 flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
 flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
-flaxdiff/trainer/simple_trainer.py,sha256=
-flaxdiff-0.1.35.
-flaxdiff-0.1.35.
-flaxdiff-0.1.35.
-flaxdiff-0.1.35.
+flaxdiff/trainer/simple_trainer.py,sha256=cawm6fZNQoLLATMneAU2gQ9j7kefqHnBPHuaIj3i_a4,18237
+flaxdiff-0.1.35.5.dist-info/METADATA,sha256=J8ikdMf9LIzdAuk6E3E77sI8NPegdgaKf5mrDHhpdXc,22085
+flaxdiff-0.1.35.5.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+flaxdiff-0.1.35.5.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.35.5.dist-info/RECORD,,
{flaxdiff-0.1.35.3.dist-info → flaxdiff-0.1.35.5.dist-info}/top_level.txt
File without changes