flaxdiff 0.1.38__tar.gz → 0.1.38.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/PKG-INFO +1 -1
  2. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/data/sources/tfds.py +7 -7
  3. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/models/attention.py +22 -16
  4. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/models/common.py +8 -18
  5. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/models/simple_unet.py +1 -12
  6. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/models/simple_vit.py +8 -12
  7. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/trainer/diffusion_trainer.py +2 -2
  8. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/trainer/simple_trainer.py +24 -6
  9. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff.egg-info/PKG-INFO +1 -1
  10. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/pyproject.toml +1 -1
  11. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/README.md +0 -0
  12. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/__init__.py +0 -0
  13. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/data/__init__.py +0 -0
  14. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/data/dataset_map.py +0 -0
  15. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/data/datasets.py +0 -0
  16. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/data/online_loader.py +0 -0
  17. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/data/sources/gcs.py +0 -0
  18. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/metrics/inception.py +0 -0
  19. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/metrics/psnr.py +0 -0
  20. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/metrics/ssim.py +0 -0
  21. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/metrics/utils.py +0 -0
  22. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/models/__init__.py +0 -0
  23. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/models/autoencoder/__init__.py +0 -0
  24. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  25. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/models/autoencoder/diffusers.py +0 -0
  26. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  27. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/models/favor_fastattn.py +0 -0
  28. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/models/general.py +0 -0
  29. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/predictors/__init__.py +0 -0
  30. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/samplers/__init__.py +0 -0
  31. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/samplers/common.py +0 -0
  32. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/samplers/ddim.py +0 -0
  33. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/samplers/ddpm.py +0 -0
  34. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/samplers/euler.py +0 -0
  35. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/samplers/heun_sampler.py +0 -0
  36. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/samplers/multistep_dpm.py +0 -0
  37. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/samplers/rk4_sampler.py +0 -0
  38. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/schedulers/__init__.py +0 -0
  39. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/schedulers/common.py +0 -0
  40. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/schedulers/continuous.py +0 -0
  41. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/schedulers/cosine.py +0 -0
  42. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/schedulers/discrete.py +0 -0
  43. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/schedulers/exp.py +0 -0
  44. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/schedulers/karras.py +0 -0
  45. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/schedulers/linear.py +0 -0
  46. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/schedulers/sqrt.py +0 -0
  47. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/trainer/__init__.py +0 -0
  48. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
  49. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/trainer/video_diffusion_trainer.py +0 -0
  50. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/utils.py +0 -0
  51. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff.egg-info/SOURCES.txt +0 -0
  52. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff.egg-info/dependency_links.txt +0 -0
  53. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff.egg-info/requires.txt +0 -0
  54. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff.egg-info/top_level.txt +0 -0
  55. {flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/setup.cfg +0 -0
{flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: flaxdiff
-Version: 0.1.38
+Version: 0.1.38.1
 Summary: A versatile and easy to understand Diffusion library
 Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
 License-Expression: MIT
{flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/data/sources/tfds.py
@@ -50,13 +50,12 @@ def tfds_augmenters(image_scale, method):
     else:
         interpolation = cv2.INTER_AREA

-    augments = augmax.Chain(
-        augmax.HorizontalFlip(0.5),
-        augmax.RandomContrast((-0.05, 0.05), 1.),
-        augmax.RandomBrightness((-0.2, 0.2), 1.)
-    )
+    from torchvision.transforms import v2

-    augments = jax.jit(augments, backend="cpu")
+    augments = v2.Compose([
+        v2.RandomHorizontalFlip(p=0.5),
+        v2.ColorJitter(brightness=0.2, contrast=0.05, saturation=0.2)
+    ])

     class augmenters(pygrain.MapTransform):
         def __init__(self, *args, **kwargs):
@@ -67,8 +66,9 @@ def tfds_augmenters(image_scale, method):
             image = element['image']
             image = cv2.resize(image, (image_scale, image_scale),
                                interpolation=interpolation)
-            # image = augments(image)
+            image = augments(image)
             # image = (image - 127.5) / 127.5
+
             caption = labelizer(element)
             results = self.tokenize(caption)
             return {
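The tfds.py change swaps the augmax-based augmentation chain for a torchvision v2 pipeline applied per sample inside the pygrain map transform. For orientation, a minimal standalone sketch of that pipeline is below; the round-trip through a CHW torch tensor is our assumption (torchvision's v2 transforms operate on tensors or PIL images), whereas the packaged code applies augments directly to the resized array.

    # Sketch of the new torchvision-based augmentation, assuming an HWC uint8
    # numpy image such as the output of cv2.resize. The tensor round-trip is an
    # assumption made for this example; it is not part of the packaged code.
    import numpy as np
    import torch
    from torchvision.transforms import v2

    augments = v2.Compose([
        v2.RandomHorizontalFlip(p=0.5),
        v2.ColorJitter(brightness=0.2, contrast=0.05, saturation=0.2),
    ])

    def augment_numpy_image(image: np.ndarray) -> np.ndarray:
        tensor = torch.from_numpy(image).permute(2, 0, 1)  # HWC -> CHW
        tensor = augments(tensor)
        return tensor.permute(1, 2, 0).numpy()             # back to HWC

    image = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)
    print(augment_numpy_image(image).shape)  # (256, 256, 3)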
{flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/models/attention.py
@@ -23,7 +23,7 @@ class EfficientAttention(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     use_bias: bool = True
-    kernel_init: Callable = kernel_init(1.0)
+    # kernel_init: Callable = kernel_init(1.0)
     force_fp32_for_softmax: bool = True

     def setup(self):
@@ -34,15 +34,21 @@ class EfficientAttention(nn.Module):
             self.heads * self.dim_head,
             precision=self.precision,
             use_bias=self.use_bias,
-            kernel_init=self.kernel_init,
+            # kernel_init=self.kernel_init,
             dtype=self.dtype
         )
         self.query = dense(name="to_q")
         self.key = dense(name="to_k")
         self.value = dense(name="to_v")

-        self.proj_attn = nn.DenseGeneral(self.query_dim, use_bias=False, precision=self.precision,
-                                         kernel_init=self.kernel_init, dtype=self.dtype, name="to_out_0")
+        self.proj_attn = nn.DenseGeneral(
+            self.query_dim,
+            use_bias=False,
+            precision=self.precision,
+            # kernel_init=self.kernel_init,
+            dtype=self.dtype,
+            name="to_out_0"
+        )
         # self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16)

     def _reshape_tensor_to_head_dim(self, tensor):
@@ -115,7 +121,7 @@ class NormalAttention(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     use_bias: bool = True
-    kernel_init: Callable = kernel_init(1.0)
+    # kernel_init: Callable = kernel_init(1.0)
     force_fp32_for_softmax: bool = True

     def setup(self):
@@ -126,7 +132,7 @@ class NormalAttention(nn.Module):
             axis=-1,
             precision=self.precision,
             use_bias=self.use_bias,
-            kernel_init=self.kernel_init,
+            # kernel_init=self.kernel_init,
             dtype=self.dtype
         )
         self.query = dense(name="to_q")
@@ -140,7 +146,7 @@ class NormalAttention(nn.Module):
             use_bias=self.use_bias,
             dtype=self.dtype,
             name="to_out_0",
-            kernel_init=self.kernel_init
+            # kernel_init=self.kernel_init
             # kernel_init=jax.nn.initializers.xavier_uniform()
         )

@@ -236,7 +242,7 @@ class BasicTransformerBlock(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     use_bias: bool = True
-    kernel_init: Callable = kernel_init(1.0)
+    # kernel_init: Callable = kernel_init(1.0)
     use_flash_attention:bool = False
     use_cross_only:bool = False
     only_pure_attention:bool = False
@@ -256,7 +262,7 @@ class BasicTransformerBlock(nn.Module):
             precision=self.precision,
             use_bias=self.use_bias,
             dtype=self.dtype,
-            kernel_init=self.kernel_init,
+            # kernel_init=self.kernel_init,
             force_fp32_for_softmax=self.force_fp32_for_softmax
         )
         self.attention2 = attenBlock(
@@ -267,7 +273,7 @@ class BasicTransformerBlock(nn.Module):
             precision=self.precision,
             use_bias=self.use_bias,
             dtype=self.dtype,
-            kernel_init=self.kernel_init,
+            # kernel_init=self.kernel_init,
             force_fp32_for_softmax=self.force_fp32_for_softmax
         )

@@ -303,7 +309,7 @@ class TransformerBlock(nn.Module):
     use_self_and_cross:bool = True
     only_pure_attention:bool = False
     force_fp32_for_softmax: bool = True
-    kernel_init: Callable = kernel_init(1.0)
+    # kernel_init: Callable = kernel_init(1.0)
     norm_inputs: bool = True
     explicitly_add_residual: bool = True

@@ -317,12 +323,12 @@ class TransformerBlock(nn.Module):
         if self.use_linear_attention:
             projected_x = nn.Dense(features=inner_dim,
                                    use_bias=False, precision=self.precision,
-                                   kernel_init=self.kernel_init,
+                                   # kernel_init=self.kernel_init,
                                    dtype=self.dtype, name=f'project_in')(x)
         else:
             projected_x = nn.Conv(
                 features=inner_dim, kernel_size=(1, 1),
-                kernel_init=self.kernel_init,
+                # kernel_init=self.kernel_init,
                 strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                 precision=self.precision, name=f'project_in_conv',
             )(x)
@@ -344,19 +350,19 @@ class TransformerBlock(nn.Module):
             use_cross_only=(not self.use_self_and_cross),
             only_pure_attention=self.only_pure_attention,
             force_fp32_for_softmax=self.force_fp32_for_softmax,
-            kernel_init=self.kernel_init
+            # kernel_init=self.kernel_init
         )(projected_x, context)

         if self.use_projection == True:
             if self.use_linear_attention:
                 projected_x = nn.Dense(features=C, precision=self.precision,
                                        dtype=self.dtype, use_bias=False,
-                                       kernel_init=self.kernel_init,
+                                       # kernel_init=self.kernel_init,
                                        name=f'project_out')(projected_x)
             else:
                 projected_x = nn.Conv(
                     features=C, kernel_size=(1, 1),
-                    kernel_init=self.kernel_init,
+                    # kernel_init=self.kernel_init,
                     strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                     precision=self.precision, name=f'project_out_conv',
                 )(projected_x)
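The attention.py hunks comment out every explicit kernel_init, so these Dense, DenseGeneral, and Conv layers now fall back to Flax's default kernel initializer (lecun_normal). A minimal sketch of what that fallback means in practice, using a standalone module rather than flaxdiff's classes:

    # Sketch: omitting kernel_init makes flax.linen layers use their default
    # initializer (lecun_normal for the kernel, zeros for the bias).
    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class OutProjection(nn.Module):
        features: int = 64

        @nn.compact
        def __call__(self, x):
            # No kernel_init argument -> the default initializer is used.
            return nn.DenseGeneral(self.features, use_bias=False, name="to_out_0")(x)

    x = jnp.ones((1, 16, 32))
    params = OutProjection().init(jax.random.PRNGKey(0), x)
    print(jax.tree_util.tree_map(jnp.shape, params))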
{flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/models/common.py
@@ -108,13 +108,16 @@ class FourierEmbedding(nn.Module):
 class TimeProjection(nn.Module):
     features:int
     activation:Callable=jax.nn.gelu
-    kernel_init:Callable=kernel_init(1.0)

     @nn.compact
     def __call__(self, x):
-        x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init)(x)
+        x = nn.DenseGeneral(
+            self.features,
+        )(x)
         x = self.activation(x)
-        x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init)(x)
+        x = nn.DenseGeneral(
+            self.features,
+        )(x)
         x = self.activation(x)
         return x

@@ -123,7 +126,6 @@ class SeparableConv(nn.Module):
     kernel_size:tuple=(3, 3)
     strides:tuple=(1, 1)
     use_bias:bool=False
-    kernel_init:Callable=kernel_init(1.0)
     padding:str="SAME"
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
@@ -133,7 +135,7 @@ class SeparableConv(nn.Module):
         in_features = x.shape[-1]
         depthwise = nn.Conv(
             features=in_features, kernel_size=self.kernel_size,
-            strides=self.strides, kernel_init=self.kernel_init,
+            strides=self.strides,
             feature_group_count=in_features, use_bias=self.use_bias,
             padding=self.padding,
             dtype=self.dtype,
@@ -141,7 +143,7 @@ class SeparableConv(nn.Module):
         )(x)
         pointwise = nn.Conv(
             features=self.features, kernel_size=(1, 1),
-            strides=(1, 1), kernel_init=self.kernel_init,
+            strides=(1, 1),
             use_bias=self.use_bias,
             dtype=self.dtype,
             precision=self.precision
@@ -153,7 +155,6 @@ class ConvLayer(nn.Module):
     features:int
     kernel_size:tuple=(3, 3)
     strides:tuple=(1, 1)
-    kernel_init:Callable=kernel_init(1.0)
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None

@@ -164,7 +165,6 @@ class ConvLayer(nn.Module):
             features=self.features,
             kernel_size=self.kernel_size,
             strides=self.strides,
-            kernel_init=self.kernel_init,
             dtype=self.dtype,
             precision=self.precision
         )
@@ -183,7 +183,6 @@ class ConvLayer(nn.Module):
             features=self.features,
             kernel_size=self.kernel_size,
             strides=self.strides,
-            kernel_init=self.kernel_init,
             dtype=self.dtype,
             precision=self.precision
         )
@@ -192,7 +191,6 @@ class ConvLayer(nn.Module):
             features=self.features,
             kernel_size=self.kernel_size,
             strides=self.strides,
-            kernel_init=self.kernel_init,
             dtype=self.dtype,
             precision=self.precision
         )
@@ -206,7 +204,6 @@ class Upsample(nn.Module):
     activation:Callable=jax.nn.swish
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
-    kernel_init:Callable=kernel_init(1.0)

     @nn.compact
     def __call__(self, x, residual=None):
@@ -221,7 +218,6 @@ class Upsample(nn.Module):
             strides=(1, 1),
             dtype=self.dtype,
             precision=self.precision,
-            kernel_init=self.kernel_init
         )(out)
         if residual is not None:
             out = jnp.concatenate([out, residual], axis=-1)
@@ -233,7 +229,6 @@ class Downsample(nn.Module):
     activation:Callable=jax.nn.swish
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
-    kernel_init:Callable=kernel_init(1.0)

     @nn.compact
     def __call__(self, x, residual=None):
@@ -244,7 +239,6 @@ class Downsample(nn.Module):
             strides=(2, 2),
             dtype=self.dtype,
             precision=self.precision,
-            kernel_init=self.kernel_init
         )(x)
         if residual is not None:
             if residual.shape[1] > out.shape[1]:
@@ -269,7 +263,6 @@ class ResidualBlock(nn.Module):
     direction:str=None
     res:int=2
     norm_groups:int=8
-    kernel_init:Callable=kernel_init(1.0)
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     named_norms:bool=False
@@ -296,7 +289,6 @@ class ResidualBlock(nn.Module):
             features=self.features,
             kernel_size=self.kernel_size,
             strides=self.strides,
-            kernel_init=self.kernel_init,
             name="conv1",
             dtype=self.dtype,
             precision=self.precision
@@ -321,7 +313,6 @@ class ResidualBlock(nn.Module):
             features=self.features,
             kernel_size=self.kernel_size,
             strides=self.strides,
-            kernel_init=self.kernel_init,
             name="conv2",
             dtype=self.dtype,
             precision=self.precision
@@ -333,7 +324,6 @@ class ResidualBlock(nn.Module):
             features=self.features,
             kernel_size=(1, 1),
             strides=1,
-            kernel_init=self.kernel_init,
             name="residual_conv",
             dtype=self.dtype,
             precision=self.precision
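The same pattern runs through common.py: kernel_init fields and arguments are dropped from TimeProjection, SeparableConv, ConvLayer, Upsample, Downsample, and ResidualBlock. As a reference point, a standalone equivalent of the post-change TimeProjection (two DenseGeneral layers with GELU activations and default initialization) looks like this:

    # Standalone sketch mirroring the post-change TimeProjection shown above.
    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class TimeProjectionSketch(nn.Module):
        features: int

        @nn.compact
        def __call__(self, t_emb):
            x = nn.DenseGeneral(self.features)(t_emb)  # default kernel init
            x = jax.nn.gelu(x)
            x = nn.DenseGeneral(self.features)(x)
            return jax.nn.gelu(x)

    t_emb = jnp.ones((4, 128))                          # e.g. a Fourier time embedding
    module = TimeProjectionSketch(features=256)
    params = module.init(jax.random.PRNGKey(0), t_emb)
    print(module.apply(params, t_emb).shape)            # (4, 256)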
{flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/models/simple_unet.py
@@ -20,7 +20,6 @@ class Unet(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     named_norms: bool = False # This is for backward compatibility reasons; older checkpoints have named norms
-    kernel_init: Callable = partial(kernel_init, dtype=jnp.float32)

     def setup(self):
         if self.norm_groups > 0:
@@ -50,7 +49,6 @@ class Unet(nn.Module):
             features=self.feature_depths[0],
             kernel_size=(3, 3),
             strides=(1, 1),
-            kernel_init=self.kernel_init(scale=1.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
@@ -65,7 +63,6 @@ class Unet(nn.Module):
                     down_conv_type,
                     name=f"down_{i}_residual_{j}",
                     features=dim_in,
-                    kernel_init=self.kernel_init(scale=1.0),
                     kernel_size=(3, 3),
                     strides=(1, 1),
                     activation=self.activation,
@@ -85,7 +82,6 @@ class Unet(nn.Module):
                         force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False),
                         norm_inputs=attention_config.get("norm_inputs", True),
                         explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
-                        kernel_init=self.kernel_init(scale=1.0),
                         name=f"down_{i}_attention_{j}")(x, textcontext)
                 # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
                 downs.append(x)
@@ -108,7 +104,6 @@ class Unet(nn.Module):
                 middle_conv_type,
                 name=f"middle_res1_{j}",
                 features=middle_dim_out,
-                kernel_init=self.kernel_init(scale=1.0),
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 activation=self.activation,
@@ -129,13 +124,11 @@ class Unet(nn.Module):
                     force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
                     norm_inputs=middle_attention.get("norm_inputs", True),
                     explicitly_add_residual=middle_attention.get("explicitly_add_residual", True),
-                    kernel_init=self.kernel_init(scale=1.0),
                     name=f"middle_attention_{j}")(x, textcontext)
             x = ResidualBlock(
                 middle_conv_type,
                 name=f"middle_res2_{j}",
                 features=middle_dim_out,
-                kernel_init=self.kernel_init(scale=1.0),
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 activation=self.activation,
@@ -157,7 +150,6 @@ class Unet(nn.Module):
                     up_conv_type,# if j == 0 else "separable",
                     name=f"up_{i}_residual_{j}",
                     features=dim_out,
-                    kernel_init=self.kernel_init(scale=1.0),
                     kernel_size=kernel_size,
                     strides=(1, 1),
                     activation=self.activation,
@@ -177,7 +169,6 @@ class Unet(nn.Module):
                         force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
                         norm_inputs=attention_config.get("norm_inputs", True),
                         explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
-                        kernel_init=self.kernel_init(scale=1.0),
                         name=f"up_{i}_attention_{j}")(x, textcontext)
                 # print("Upscaling ", i, x.shape)
                 if i != len(feature_depths) - 1:
@@ -196,7 +187,6 @@ class Unet(nn.Module):
             features=self.feature_depths[0],
             kernel_size=(3, 3),
             strides=(1, 1),
-            kernel_init=self.kernel_init(scale=1.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
@@ -207,7 +197,6 @@ class Unet(nn.Module):
             conv_type,
             name="final_residual",
             features=self.feature_depths[0],
-            kernel_init=self.kernel_init(scale=1.0),
             kernel_size=(3,3),
             strides=(1, 1),
             activation=self.activation,
@@ -226,7 +215,7 @@ class Unet(nn.Module):
             kernel_size=(3, 3),
             strides=(1, 1),
             # activation=jax.nn.mish
-            kernel_init=self.kernel_init(scale=0.0),
+            # kernel_init=self.kernel_init(scale=0.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
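Most simple_unet.py hunks drop kernel_init=self.kernel_init(scale=1.0) from ResidualBlock, TransformerBlock, and ConvLayer calls, but the last one is worth a closer look: the final 3x3 output convolution previously used kernel_init(scale=0.0), which (assuming kernel_init wraps a variance-scaling initializer, as its usage suggests) zero-initializes the output so the untrained network starts by predicting zeros; with the argument commented out it now gets the default initializer. A hedged sketch of that difference with plain flax.linen.Conv:

    # Illustration of the output-convolution initialization change, using plain
    # flax.linen.Conv; nn.initializers.zeros stands in for kernel_init(scale=0.0).
    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    x = jnp.ones((1, 32, 32, 64))

    zero_out = nn.Conv(features=3, kernel_size=(3, 3), kernel_init=nn.initializers.zeros)
    p0 = zero_out.init(jax.random.PRNGKey(0), x)
    print(float(jnp.abs(zero_out.apply(p0, x)).max()))     # 0.0 at initialization

    default_out = nn.Conv(features=3, kernel_size=(3, 3))
    p1 = default_out.init(jax.random.PRNGKey(0), x)
    print(float(jnp.abs(default_out.apply(p1, x)).max()))  # non-zero at initialization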
{flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/models/simple_vit.py
@@ -23,7 +23,6 @@ class PatchEmbedding(nn.Module):
     embedding_dim: int
     dtype: Any = jnp.float32
     precision: Any = jax.lax.Precision.HIGH
-    kernel_init: Callable = partial(kernel_init, 1.0)

     @nn.compact
     def __call__(self, x):
@@ -34,7 +33,6 @@ class PatchEmbedding(nn.Module):
                     kernel_size=(self.patch_size, self.patch_size),
                     strides=(self.patch_size, self.patch_size),
                     dtype=self.dtype,
-                    kernel_init=self.kernel_init(),
                     precision=self.precision)(x)
         x = jnp.reshape(x, (batch, -1, self.embedding_dim))
         return x
@@ -67,7 +65,7 @@ class UViT(nn.Module):
     norm_groups:int=8
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
-    kernel_init: Callable = partial(kernel_init, scale=1.0)
+    # kernel_init: Callable = partial(kernel_init, scale=1.0)
     add_residualblock_output: bool = False
     norm_inputs: bool = False
     explicitly_add_residual: bool = True
@@ -88,10 +86,10 @@ class UViT(nn.Module):

         # Patch embedding
         x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features,
-                           dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init)(x)
+                           dtype=self.dtype, precision=self.precision)(x)
         num_patches = x.shape[1]

-        context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
+        context_emb = nn.DenseGeneral(features=self.emb_features,
                                       dtype=self.dtype, precision=self.precision)(textcontext)
         num_text_tokens = textcontext.shape[1]

@@ -116,7 +114,7 @@ class UViT(nn.Module):
                                  only_pure_attention=False,
                                  norm_inputs=self.norm_inputs,
                                  explicitly_add_residual=self.explicitly_add_residual,
-                                 kernel_init=self.kernel_init())(x)
+                                 )(x)
             skips.append(x)

         # Middle block
@@ -126,12 +124,12 @@ class UViT(nn.Module):
                              only_pure_attention=False,
                              norm_inputs=self.norm_inputs,
                              explicitly_add_residual=self.explicitly_add_residual,
-                             kernel_init=self.kernel_init())(x)
+                             )(x)

         # # Out blocks
         for i in range(self.num_layers // 2):
             x = jnp.concatenate([x, skips.pop()], axis=-1)
-            x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
+            x = nn.DenseGeneral(features=self.emb_features,
                                 dtype=self.dtype, precision=self.precision)(x)
             x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                                  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
@@ -139,13 +137,13 @@ class UViT(nn.Module):
                                  only_pure_attention=False,
                                  norm_inputs=self.norm_inputs,
                                  explicitly_add_residual=self.explicitly_add_residual,
-                                 kernel_init=self.kernel_init())(x)
+                                 )(x)

         # print(f'Shape of x after transformer blocks: {x.shape}')
         x = self.norm()(x)

         patch_dim = self.patch_size ** 2 * self.output_channels
-        x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init())(x)
+        x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision)(x)
         x = x[:, 1 + num_text_tokens:, :]
         x = unpatchify(x, channels=self.output_channels)

@@ -159,7 +157,6 @@ class UViT(nn.Module):
             kernel_size=(3, 3),
             strides=(1, 1),
             # activation=jax.nn.mish
-            kernel_init=self.kernel_init(scale=0.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
@@ -173,7 +170,6 @@ class UViT(nn.Module):
             kernel_size=(3, 3),
             strides=(1, 1),
             # activation=jax.nn.mish
-            kernel_init=self.kernel_init(scale=0.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
{flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/trainer/diffusion_trainer.py
@@ -231,11 +231,11 @@ class DiffusionTrainer(SimpleTrainer):
                 ),
             )

-            train_state = new_state.apply_ema(self.ema_decay)
+            new_state = new_state.apply_ema(self.ema_decay)

             if distributed_training:
                 loss = jax.lax.pmean(loss, "data")
-            return train_state, loss, rng_state
+            return new_state, loss, rng_state

         if distributed_training:
             train_step = shard_map(
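In diffusion_trainer.py, the EMA-updated state is now kept under the new_state name and returned directly, instead of rebinding the incoming train_state. apply_ema itself is not part of this diff; a generic sketch of that kind of exponential-moving-average update follows, where the EmaState class and its field names are assumptions for illustration, not flaxdiff's TrainState.

    # Generic EMA-update sketch; EmaState and its fields are assumptions.
    import jax
    import jax.numpy as jnp
    from flax import struct

    @struct.dataclass
    class EmaState:
        params: dict
        ema_params: dict

        def apply_ema(self, decay: float = 0.999) -> "EmaState":
            # Blend current params into the running average of the parameters.
            new_ema = jax.tree_util.tree_map(
                lambda ema, p: decay * ema + (1.0 - decay) * p,
                self.ema_params, self.params)
            return self.replace(ema_params=new_ema)

    state = EmaState(params={"w": jnp.ones(3)}, ema_params={"w": jnp.zeros(3)})
    state = state.apply_ema(decay=0.9)
    print(state.ema_params["w"])  # [0.1 0.1 0.1]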
{flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff/trainer/simple_trainer.py
@@ -159,7 +159,7 @@ class SimpleTrainer:
         self.best_loss = 1e9

     def get_input_ones(self):
-        return {k: jnp.ones((1, *v)) for k, v in self.input_shapes.items()}
+        return {k: jnp.ones((1, *v), dtype=self.model.dtype) for k, v in self.input_shapes.items()}

     def generate_states(
         self,
@@ -437,12 +437,30 @@ class SimpleTrainer:
                     # If the loss is too low, we can assume the model has diverged
                     print(colored(f"Loss too low at step {current_step} => {loss}", 'red'))
                     # Reset the model to the old state
-                    if self.best_state is not None:
-                        print(colored(f"Resetting model to best state", 'red'))
-                        train_state = self.best_state
-                        loss = self.best_loss
+                    # if self.best_state is not None:
+                    #     print(colored(f"Resetting model to best state", 'red'))
+                    #     train_state = self.best_state
+                    #     loss = self.best_loss
+                    # else:
+                    #     exit(1)
+
+                    # Check if there are any NaN/inf values in the train_state.params
+                    params = train_state.params
+                    if isinstance(params, dict):
+                        for key, value in params.items():
+                            if isinstance(value, jnp.ndarray):
+                                if jnp.isnan(value).any() or jnp.isinf(value).any():
+                                    print(colored(f"NaN/inf values found in params at step {current_step}", 'red'))
+                                    # Reset the model to the old state
+                                    # train_state = self.best_state
+                                    # loss = self.best_loss
+                                    # break
+                                else:
+                                    print(colored(f"Params are fine at step {current_step}", 'green'))
                     else:
-                        exit(1)
+                        print(colored(f"Params are not a dict at step {current_step}", 'red'))
+
+                    exit(1)

                 epoch_loss += loss
                 current_step += 1
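The simple_trainer.py changes build the dummy inputs for initialization in the model's dtype and, when a divergence is detected, scan train_state.params for NaN/inf values before exiting instead of restoring best_state. The loop added above only inspects top-level entries of the params dict; a pytree-wide variant (our sketch, not part of the package) that reports the offending leaves could look like this:

    # Sketch: scan every leaf of a params pytree for NaN/inf values.
    import jax
    import jax.numpy as jnp

    def find_bad_leaves(params):
        """Return the key paths of leaves that contain NaN or inf values."""
        leaves, _ = jax.tree_util.tree_flatten_with_path(params)
        return [jax.tree_util.keystr(path)
                for path, leaf in leaves
                if jnp.isnan(leaf).any() or jnp.isinf(leaf).any()]

    params = {"down_0": {"kernel": jnp.ones((3, 3))},
              "final_conv": {"kernel": jnp.array([1.0, jnp.inf])}}
    print(find_bad_leaves(params))  # ["['final_conv']['kernel']"]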
{flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/flaxdiff.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: flaxdiff
-Version: 0.1.38
+Version: 0.1.38.1
 Summary: A versatile and easy to understand Diffusion library
 Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
 License-Expression: MIT
{flaxdiff-0.1.38 → flaxdiff-0.1.38.1}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "flaxdiff"
-version = "0.1.38"
+version = "0.1.38.1"
 description = "A versatile and easy to understand Diffusion library"
 readme = "README.md"
 authors = [