flaxdiff 0.1.9__tar.gz → 0.1.11__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/PKG-INFO +1 -1
  2. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/models/attention.py +17 -8
  3. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/models/common.py +5 -3
  4. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/models/simple_unet.py +28 -16
  5. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/trainer/diffusion_trainer.py +4 -2
  6. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/trainer/simple_trainer.py +6 -3
  7. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff.egg-info/PKG-INFO +1 -1
  8. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/setup.py +1 -1
  9. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/README.md +0 -0
  10. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/__init__.py +0 -0
  11. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/models/__init__.py +0 -0
  12. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/models/autoencoder/__init__.py +0 -0
  13. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  14. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/models/autoencoder/diffusers.py +0 -0
  15. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  16. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/models/favor_fastattn.py +0 -0
  17. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/models/simple_vit.py +0 -0
  18. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/predictors/__init__.py +0 -0
  19. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/samplers/__init__.py +0 -0
  20. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/samplers/common.py +0 -0
  21. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/samplers/ddim.py +0 -0
  22. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/samplers/ddpm.py +0 -0
  23. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/samplers/euler.py +0 -0
  24. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/samplers/heun_sampler.py +0 -0
  25. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/samplers/multistep_dpm.py +0 -0
  26. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/samplers/rk4_sampler.py +0 -0
  27. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/schedulers/__init__.py +0 -0
  28. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/schedulers/common.py +0 -0
  29. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/schedulers/continuous.py +0 -0
  30. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/schedulers/cosine.py +0 -0
  31. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/schedulers/discrete.py +0 -0
  32. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/schedulers/exp.py +0 -0
  33. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/schedulers/karras.py +0 -0
  34. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/schedulers/linear.py +0 -0
  35. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/schedulers/sqrt.py +0 -0
  36. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/trainer/__init__.py +0 -0
  37. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
  38. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/utils.py +0 -0
  39. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff.egg-info/SOURCES.txt +0 -0
  40. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff.egg-info/dependency_links.txt +0 -0
  41. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff.egg-info/requires.txt +0 -0
  42. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff.egg-info/top_level.txt +0 -0
  43. {flaxdiff-0.1.9 → flaxdiff-0.1.11}/setup.cfg +0 -0
{flaxdiff-0.1.9 → flaxdiff-0.1.11}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.9
+Version: 0.1.11
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
{flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/models/attention.py
@@ -23,6 +23,7 @@ class EfficientAttention(nn.Module):
     precision: PrecisionLike = None
     use_bias: bool = True
     kernel_init: Callable = lambda : kernel_init(1.0)
+    force_fp32_for_softmax: bool = True
 
     def setup(self):
         inner_dim = self.dim_head * self.heads
@@ -114,6 +115,7 @@ class NormalAttention(nn.Module):
     precision: PrecisionLike = None
     use_bias: bool = True
     kernel_init: Callable = lambda : kernel_init(1.0)
+    force_fp32_for_softmax: bool = True
 
     def setup(self):
         inner_dim = self.dim_head * self.heads
@@ -157,7 +159,7 @@ class NormalAttention(nn.Module):
 
         hidden_states = nn.dot_product_attention(
             query, key, value, dtype=self.dtype, broadcast_dropout=False,
-            dropout_rng=None, precision=self.precision, force_fp32_for_softmax=True,
+            dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax,
             deterministic=True
         )
         proj = self.proj_attn(hidden_states)
@@ -237,6 +239,7 @@ class BasicTransformerBlock(nn.Module):
     use_flash_attention:bool = False
     use_cross_only:bool = False
     only_pure_attention:bool = False
+    force_fp32_for_softmax: bool = True
 
     def setup(self):
         if self.use_flash_attention:
@@ -252,7 +255,8 @@ class BasicTransformerBlock(nn.Module):
             precision=self.precision,
             use_bias=self.use_bias,
             dtype=self.dtype,
-            kernel_init=self.kernel_init
+            kernel_init=self.kernel_init,
+            force_fp32_for_softmax=self.force_fp32_for_softmax
         )
         self.attention2 = attenBlock(
             query_dim=self.query_dim,
@@ -262,7 +266,8 @@ class BasicTransformerBlock(nn.Module):
             precision=self.precision,
             use_bias=self.use_bias,
             dtype=self.dtype,
-            kernel_init=self.kernel_init
+            kernel_init=self.kernel_init,
+            force_fp32_for_softmax=self.force_fp32_for_softmax
         )
 
         self.ff = FlaxFeedForward(dim=self.query_dim)
@@ -296,6 +301,8 @@ class TransformerBlock(nn.Module):
     use_flash_attention:bool = False
     use_self_and_cross:bool = True
     only_pure_attention:bool = False
+    force_fp32_for_softmax: bool = True
+    kernel_init: Callable = lambda : kernel_init(1.0)
 
     @nn.compact
     def __call__(self, x, context=None):
@@ -306,12 +313,12 @@ class TransformerBlock(nn.Module):
         if self.use_linear_attention:
             projected_x = nn.Dense(features=inner_dim,
                                    use_bias=False, precision=self.precision,
-                                   kernel_init=kernel_init(1.0),
+                                   kernel_init=self.kernel_init(),
                                    dtype=self.dtype, name=f'project_in')(normed_x)
         else:
             projected_x = nn.Conv(
                 features=inner_dim, kernel_size=(1, 1),
-                kernel_init=kernel_init(1.0),
+                kernel_init=self.kernel_init(),
                 strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                 precision=self.precision, name=f'project_in_conv',
             )(normed_x)
@@ -331,19 +338,21 @@ class TransformerBlock(nn.Module):
             dtype=self.dtype,
             use_flash_attention=self.use_flash_attention,
             use_cross_only=(not self.use_self_and_cross),
-            only_pure_attention=self.only_pure_attention
+            only_pure_attention=self.only_pure_attention,
+            force_fp32_for_softmax=self.force_fp32_for_softmax,
+            kernel_init=self.kernel_init
         )(projected_x, context)
 
         if self.use_projection == True:
             if self.use_linear_attention:
                 projected_x = nn.Dense(features=C, precision=self.precision,
                                        dtype=self.dtype, use_bias=False,
-                                       kernel_init=kernel_init(1.0),
+                                       kernel_init=self.kernel_init(),
                                        name=f'project_out')(projected_x)
             else:
                 projected_x = nn.Conv(
                     features=C, kernel_size=(1, 1),
-                    kernel_init=kernel_init(1.0),
+                    kernel_init=self.kernel_init(),
                     strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                     precision=self.precision, name=f'project_out_conv',
                 )(projected_x)
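These hunks make the attention softmax dtype configurable from the caller instead of always forcing float32. A minimal construction sketch, using only module fields visible in the hunks above; the hyperparameter values are illustrative, not taken from the package, and other required fields may exist:

import jax.numpy as jnp
from flaxdiff.models.attention import TransformerBlock

# Illustrative only: run the attention softmax in the compute dtype
# (e.g. bfloat16) instead of forcing float32. force_fp32_for_softmax
# defaults to True, which matches the pre-0.1.11 behaviour.
block = TransformerBlock(
    heads=8,
    dtype=jnp.bfloat16,
    use_self_and_cross=True,
    only_pure_attention=True,
    force_fp32_for_softmax=False,
)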
{flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/models/common.py
@@ -267,15 +267,17 @@ class ResidualBlock(nn.Module):
     kernel_init:Callable=kernel_init(1.0)
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
+    named_norms:bool=False
 
     def setup(self):
         if self.norm_groups > 0:
             norm = partial(nn.GroupNorm, self.norm_groups)
+            self.norm1 = norm(name="GroupNorm_0") if self.named_norms else norm()
+            self.norm2 = norm(name="GroupNorm_1") if self.named_norms else norm()
         else:
             norm = partial(nn.RMSNorm, 1e-5)
-
-        self.norm1 = norm()
-        self.norm2 = norm()
+            self.norm1 = norm()
+            self.norm2 = norm()
 
     @nn.compact
     def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None):
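The new named_norms flag exists so the GroupNorm layers keep the explicit parameter names (GroupNorm_0, GroupNorm_1) that older flaxdiff checkpoints were saved with. A hedged sketch of constructing the block for checkpoint compatibility; the conv-type string and all hyperparameter values are assumptions for illustration:

from flaxdiff.models.common import ResidualBlock

# Hypothetical hyperparameters. named_norms=True recreates the
# GroupNorm_0 / GroupNorm_1 parameter names that pre-0.1.11 checkpoints
# expect; named_norms=False (the default) uses Flax's automatic names.
block = ResidualBlock(
    "conv",            # conv_type; "conv" is an assumed value
    features=64,
    kernel_size=(3, 3),
    strides=(1, 1),
    norm_groups=8,
    named_norms=True,
)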
{flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/models/simple_unet.py
@@ -19,15 +19,16 @@ class Unet(nn.Module):
     norm_groups:int=8
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
+    named_norms: bool = False # This is for backward compatibility reasons; older checkpoints have named norms
+    kernel_init: Callable = partial(kernel_init, dtype=jnp.float32)
 
     def setup(self):
         if self.norm_groups > 0:
             norm = partial(nn.GroupNorm, self.norm_groups)
+            self.conv_out_norm = norm(name="GroupNorm_0") if self.named_norms else norm()
         else:
             norm = partial(nn.RMSNorm, 1e-5)
-
-        # self.last_up_norm = norm()
-        self.conv_out_norm = norm()
+            self.conv_out_norm = norm()
 
     @nn.compact
     def __call__(self, x, temb, textcontext):
@@ -49,7 +50,7 @@
             features=self.feature_depths[0],
             kernel_size=(3, 3),
             strides=(1, 1),
-            kernel_init=kernel_init(1.0),
+            kernel_init=self.kernel_init(1.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
@@ -64,13 +65,14 @@
             down_conv_type,
             name=f"down_{i}_residual_{j}",
             features=dim_in,
-            kernel_init=kernel_init(1.0),
+            kernel_init=self.kernel_init(1.0),
             kernel_size=(3, 3),
             strides=(1, 1),
             activation=self.activation,
             norm_groups=self.norm_groups,
             dtype=self.dtype,
-            precision=self.precision
+            precision=self.precision,
+            named_norms=self.named_norms
         )(x, temb)
         if attention_config is not None and j == self.num_res_blocks - 1:   # Apply attention only on the last block
             x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
@@ -80,6 +82,8 @@
                 use_self_and_cross=attention_config.get("use_self_and_cross", True),
                 precision=attention_config.get("precision", self.precision),
                 only_pure_attention=attention_config.get("only_pure_attention", True),
+                force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False),
+                kernel_init=self.kernel_init(1.0),
                 name=f"down_{i}_attention_{j}")(x, textcontext)
             # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
             downs.append(x)
@@ -102,13 +106,14 @@
             middle_conv_type,
             name=f"middle_res1_{j}",
             features=middle_dim_out,
-            kernel_init=kernel_init(1.0),
+            kernel_init=self.kernel_init(1.0),
             kernel_size=(3, 3),
             strides=(1, 1),
             activation=self.activation,
             norm_groups=self.norm_groups,
             dtype=self.dtype,
-            precision=self.precision
+            precision=self.precision,
+            named_norms=self.named_norms
         )(x, temb)
         if middle_attention is not None and j == self.num_middle_res_blocks - 1:   # Apply attention only on the last block
             x = TransformerBlock(heads=middle_attention['heads'], dtype=middle_attention.get('dtype', jnp.float32),
@@ -119,18 +124,21 @@
                 use_self_and_cross=False,
                 precision=middle_attention.get("precision", self.precision),
                 only_pure_attention=middle_attention.get("only_pure_attention", True),
+                force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
+                kernel_init=self.kernel_init(1.0),
                 name=f"middle_attention_{j}")(x, textcontext)
         x = ResidualBlock(
             middle_conv_type,
             name=f"middle_res2_{j}",
             features=middle_dim_out,
-            kernel_init=kernel_init(1.0),
+            kernel_init=self.kernel_init(1.0),
             kernel_size=(3, 3),
             strides=(1, 1),
             activation=self.activation,
             norm_groups=self.norm_groups,
             dtype=self.dtype,
-            precision=self.precision
+            precision=self.precision,
+            named_norms=self.named_norms
         )(x, temb)
 
         # Upscaling Blocks
@@ -145,13 +153,14 @@
                 up_conv_type,# if j == 0 else "separable",
                 name=f"up_{i}_residual_{j}",
                 features=dim_out,
-                kernel_init=kernel_init(1.0),
+                kernel_init=self.kernel_init(1.0),
                 kernel_size=kernel_size,
                 strides=(1, 1),
                 activation=self.activation,
                 norm_groups=self.norm_groups,
                 dtype=self.dtype,
-                precision=self.precision
+                precision=self.precision,
+                named_norms=self.named_norms
             )(x, temb)
             if attention_config is not None and j == self.num_res_blocks - 1:   # Apply attention only on the last block
                 x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
@@ -161,6 +170,8 @@
                     use_self_and_cross=attention_config.get("use_self_and_cross", True),
                     precision=attention_config.get("precision", self.precision),
                     only_pure_attention=attention_config.get("only_pure_attention", True),
+                    force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
+                    kernel_init=self.kernel_init(1.0),
                     name=f"up_{i}_attention_{j}")(x, textcontext)
             # print("Upscaling ", i, x.shape)
             if i != len(feature_depths) - 1:
@@ -179,7 +190,7 @@
             features=self.feature_depths[0],
             kernel_size=(3, 3),
             strides=(1, 1),
-            kernel_init=kernel_init(1.0),
+            kernel_init=self.kernel_init(1.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
@@ -190,13 +201,14 @@
             conv_type,
             name="final_residual",
             features=self.feature_depths[0],
-            kernel_init=kernel_init(1.0),
+            kernel_init=self.kernel_init(1.0),
             kernel_size=(3,3),
             strides=(1, 1),
             activation=self.activation,
             norm_groups=self.norm_groups,
             dtype=self.dtype,
-            precision=self.precision
+            precision=self.precision,
+            named_norms=self.named_norms
         )(x, temb)
 
         x = self.conv_out_norm(x)
@@ -208,7 +220,7 @@
             kernel_size=(3, 3),
             strides=(1, 1),
             # activation=jax.nn.mish
-            kernel_init=kernel_init(0.0),
+            kernel_init=self.kernel_init(0.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
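At the model level, Unet now exposes named_norms (to reproduce the named GroupNorm parameters of older checkpoints) and a kernel_init factory, and it forwards force_fp32_for_softmax from each attention config. A minimal construction sketch; every value is illustrative, only fields visible in the hunks above are used, and the class may have further required fields not shown here:

import jax.numpy as jnp
from flaxdiff.models.simple_unet import Unet

# Illustrative values only. named_norms=True keeps the GroupNorm_0-style
# parameter names so pre-0.1.11 checkpoints can still be restored;
# force_fp32_for_softmax is read per attention config and defaults to
# False there, per the .get(...) calls above.
unet = Unet(
    feature_depths=(64, 128, 256),   # field name inferred from self.feature_depths
    norm_groups=8,
    dtype=jnp.bfloat16,
    named_norms=True,
)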
{flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/trainer/diffusion_trainer.py
@@ -16,6 +16,7 @@ from flaxdiff.utils import RandomMarkovState
 from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
 
 from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
+from flax.training.dynamic_scale import DynamicScale
 
 class TrainState(SimpleTrainState):
     rngs: jax.random.PRNGKey
@@ -83,7 +84,8 @@ class DiffusionTrainer(SimpleTrainer):
             new_state = existing_state
 
         if param_transforms is not None:
-            params = param_transforms(params)
+            new_state['params'] = param_transforms(new_state['params'])
+            new_state['ema_params'] = param_transforms(new_state['ema_params'])
 
         state = TrainState.create(
             apply_fn=model.apply,
@@ -92,7 +94,7 @@ class DiffusionTrainer(SimpleTrainer):
             tx=optimizer,
             rngs=rngs,
             metrics=Metrics.empty(),
-            dynamic_scale = flax.training.dynamic_scale.DynamicScale() if use_dynamic_scale else None
+            dynamic_scale = DynamicScale() if use_dynamic_scale else None
         )
 
         if existing_best_state is not None:
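Note that param_transforms is now applied to both the restored params and the restored ema_params, so a transform written for checkpoint migration keeps the EMA weights consistent with the training weights. A hypothetical example of such a callable (not part of the package; it simply casts every leaf):

import jax
import jax.numpy as jnp

def param_transforms(params):
    # Hypothetical transform for illustration: cast every restored weight
    # to bfloat16. As of 0.1.11 this runs on both 'params' and 'ema_params'
    # before the new TrainState is created.
    return jax.tree_util.tree_map(lambda p: p.astype(jnp.bfloat16), params)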
{flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/trainer/simple_trainer.py
@@ -22,7 +22,7 @@ from jax.experimental.shard_map import shard_map
 from orbax.checkpoint.utils import fully_replicated_host_local_array_to_global_array
 from termcolor import colored
 from typing import Dict, Callable, Sequence, Any, Union, Tuple
-
+from flax.training.dynamic_scale import DynamicScale
 from flaxdiff.utils import RandomMarkovState
 
 PROCESS_COLOR_MAP = {
@@ -68,7 +68,7 @@ class Metrics(metrics.Collection):
 # Define the TrainState
 class SimpleTrainState(train_state.TrainState):
     metrics: Metrics
-    dynamic_scale: flax.training.dynamic_scale.DynamicScale
+    dynamic_scale: DynamicScale
 
 class SimpleTrainer:
     state: SimpleTrainState
@@ -177,13 +177,16 @@ class SimpleTrainer:
             params = model.init(subkey, **input_vars)
         else:
             params = existing_state['params']
+
+        if param_transforms is not None:
+            params = param_transforms(params)
 
         state = SimpleTrainState.create(
             apply_fn=model.apply,
             params=params,
             tx=optimizer,
             metrics=Metrics.empty(),
-            dynamic_scale = flax.training.dynamic_scale.DynamicScale() if use_dynamic_scale else None
+            dynamic_scale = DynamicScale() if use_dynamic_scale else None
         )
         if existing_best_state is not None:
             best_state = state.replace(
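Both trainers now import DynamicScale directly from flax.training.dynamic_scale instead of reaching through the flax.training namespace. For context, a small self-contained example of the standard Flax loss-scaling pattern this class supports (general Flax usage, not flaxdiff-specific code):

import jax.numpy as jnp
from flax.training.dynamic_scale import DynamicScale

def loss_fn(params):
    # Toy float16 loss so the dynamic scale has something to protect.
    return jnp.sum(params.astype(jnp.float16) ** 2)

params = jnp.ones((4,), jnp.float32)
dyn_scale = DynamicScale()
# value_and_grad scales the loss internally and reports, via is_finite,
# whether the unscaled gradients overflowed, so the caller can skip
# the optimizer step for that batch.
dyn_scale, is_finite, loss, grads = dyn_scale.value_and_grad(loss_fn)(params)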
{flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.9
+Version: 0.1.11
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
{flaxdiff-0.1.9 → flaxdiff-0.1.11}/setup.py
@@ -11,7 +11,7 @@ required_packages=[
 setup(
     name='flaxdiff',
     packages=find_packages(),
-    version='0.1.9',
+    version='0.1.11',
     description='A versatile and easy to understand Diffusion library',
     long_description=open('README.md').read(),
     long_description_content_type='text/markdown',