flaxdiff 0.1.9__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flaxdiff/models/common.py CHANGED
@@ -267,15 +267,17 @@ class ResidualBlock(nn.Module):
     kernel_init:Callable=kernel_init(1.0)
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
+    named_norms:bool=False

     def setup(self):
         if self.norm_groups > 0:
             norm = partial(nn.GroupNorm, self.norm_groups)
+            self.norm1 = norm(name="GroupNorm_0") if self.named_norms else norm()
+            self.norm2 = norm(name="GroupNorm_1") if self.named_norms else norm()
         else:
             norm = partial(nn.RMSNorm, 1e-5)
-
-        self.norm1 = norm()
-        self.norm2 = norm()
+            self.norm1 = norm()
+            self.norm2 = norm()

     @nn.compact
     def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None):
flaxdiff/models/simple_unet.py CHANGED
@@ -19,15 +19,15 @@ class Unet(nn.Module):
     norm_groups:int=8
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
+    named_norms: bool = False # This is for backward compatibility reasons; older checkpoints have named norms

     def setup(self):
         if self.norm_groups > 0:
             norm = partial(nn.GroupNorm, self.norm_groups)
+            self.conv_out_norm = norm(name="GroupNorm_0") if self.named_norms else norm()
         else:
             norm = partial(nn.RMSNorm, 1e-5)
-
-        # self.last_up_norm = norm()
-        self.conv_out_norm = norm()
+            self.conv_out_norm = norm()

     @nn.compact
     def __call__(self, x, temb, textcontext):
@@ -70,7 +70,8 @@ class Unet(nn.Module):
                     activation=self.activation,
                     norm_groups=self.norm_groups,
                     dtype=self.dtype,
-                    precision=self.precision
+                    precision=self.precision,
+                    named_norms=self.named_norms
                 )(x, temb)
                 if attention_config is not None and j == self.num_res_blocks - 1:   # Apply attention only on the last block
                     x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
@@ -108,7 +109,8 @@ class Unet(nn.Module):
                 activation=self.activation,
                 norm_groups=self.norm_groups,
                 dtype=self.dtype,
-                precision=self.precision
+                precision=self.precision,
+                named_norms=self.named_norms
             )(x, temb)
             if middle_attention is not None and j == self.num_middle_res_blocks - 1:   # Apply attention only on the last block
                 x = TransformerBlock(heads=middle_attention['heads'], dtype=middle_attention.get('dtype', jnp.float32),
@@ -130,7 +132,8 @@ class Unet(nn.Module):
             activation=self.activation,
             norm_groups=self.norm_groups,
             dtype=self.dtype,
-            precision=self.precision
+            precision=self.precision,
+            named_norms=self.named_norms
         )(x, temb)

         # Upscaling Blocks
@@ -151,7 +154,8 @@ class Unet(nn.Module):
                     activation=self.activation,
                     norm_groups=self.norm_groups,
                     dtype=self.dtype,
-                    precision=self.precision
+                    precision=self.precision,
+                    named_norms=self.named_norms
                 )(x, temb)
                 if attention_config is not None and j == self.num_res_blocks - 1:   # Apply attention only on the last block
                     x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
@@ -196,7 +200,8 @@ class Unet(nn.Module):
             activation=self.activation,
             norm_groups=self.norm_groups,
             dtype=self.dtype,
-            precision=self.precision
+            precision=self.precision,
+            named_norms=self.named_norms
         )(x, temb)

         x = self.conv_out_norm(x)
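
Note: the new named_norms flag exists purely for checkpoint compatibility. Versions before 0.1.10 created their GroupNorm layers under explicit names ("GroupNorm_0", "GroupNorm_1"), so parameters saved by those versions only resolve if the same names are used again. A minimal usage sketch, assuming the remaining Unet constructor arguments keep their defaults (this example is not part of the package):

    from flaxdiff.models.simple_unet import Unet

    # norm_groups > 0 selects nn.GroupNorm, the only path where naming applies;
    # named_norms=True reproduces the pre-0.1.10 parameter names so older
    # checkpoints can be restored without any key remapping.
    model = Unet(norm_groups=8, named_norms=True)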
flaxdiff/trainer/diffusion_trainer.py CHANGED
@@ -16,6 +16,7 @@ from flaxdiff.utils import RandomMarkovState
 from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics

 from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
+from flax.training.dynamic_scale import DynamicScale

 class TrainState(SimpleTrainState):
     rngs: jax.random.PRNGKey
@@ -83,7 +84,8 @@ class DiffusionTrainer(SimpleTrainer):
             new_state = existing_state

         if param_transforms is not None:
-            params = param_transforms(params)
+            new_state['params'] = param_transforms(new_state['params'])
+            new_state['ema_params'] = param_transforms(new_state['ema_params'])

         state = TrainState.create(
             apply_fn=model.apply,
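
Note: this hunk fixes a bug where param_transforms was applied to a local variable that was then discarded when restoring an existing state; it now rewrites both the live parameters and their EMA copy. param_transforms is simply a caller-supplied function from a parameter pytree to a parameter pytree. A hypothetical example (not from the package):

    import jax
    import jax.numpy as jnp

    def param_transforms(params):
        # Cast every parameter leaf to float32 before resuming training.
        return jax.tree_util.tree_map(lambda p: p.astype(jnp.float32), params)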
@@ -92,7 +94,7 @@ class DiffusionTrainer(SimpleTrainer):
             tx=optimizer,
             rngs=rngs,
             metrics=Metrics.empty(),
-            dynamic_scale = flax.training.dynamic_scale.DynamicScale() if use_dynamic_scale else None
+            dynamic_scale = DynamicScale() if use_dynamic_scale else None
         )

         if existing_best_state is not None:
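
Note: DynamicScale is Flax's dynamic loss-scaling helper for mixed-precision training; the change here only swaps the fully qualified reference for the new top-level import. For context, a self-contained sketch of the pattern it supports (the toy loss and names are illustrative, not this trainer's code):

    import jax.numpy as jnp
    from flax.training.dynamic_scale import DynamicScale

    def scaled_grads(params, x, y):
        def loss_fn(p):
            # Toy squared-error loss standing in for the real objective.
            return jnp.mean((x * p["w"] - y) ** 2)

        # value_and_grad scales the loss before differentiation, unscales the
        # gradients afterwards, and reports whether they are all finite so the
        # optimizer update can be skipped (and the scale lowered) on overflow.
        ds = DynamicScale()
        new_ds, is_finite, loss, grads = ds.value_and_grad(loss_fn)(params)
        return new_ds, is_finite, loss, grads

    _, finite, loss, grads = scaled_grads({"w": jnp.ones(())}, jnp.arange(4.0), jnp.zeros(4))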
flaxdiff/trainer/simple_trainer.py CHANGED
@@ -22,7 +22,7 @@ from jax.experimental.shard_map import shard_map
 from orbax.checkpoint.utils import fully_replicated_host_local_array_to_global_array
 from termcolor import colored
 from typing import Dict, Callable, Sequence, Any, Union, Tuple
-
+from flax.training.dynamic_scale import DynamicScale
 from flaxdiff.utils import RandomMarkovState

 PROCESS_COLOR_MAP = {
@@ -68,7 +68,7 @@ class Metrics(metrics.Collection):
 # Define the TrainState
 class SimpleTrainState(train_state.TrainState):
     metrics: Metrics
-    dynamic_scale: flax.training.dynamic_scale.DynamicScale
+    dynamic_scale: DynamicScale

 class SimpleTrainer:
     state: SimpleTrainState
@@ -177,13 +177,16 @@ class SimpleTrainer:
             params = model.init(subkey, **input_vars)
         else:
             params = existing_state['params']
+
+        if param_transforms is not None:
+            params = param_transforms(params)

         state = SimpleTrainState.create(
             apply_fn=model.apply,
             params=params,
             tx=optimizer,
             metrics=Metrics.empty(),
-            dynamic_scale = flax.training.dynamic_scale.DynamicScale() if use_dynamic_scale else None
+            dynamic_scale = DynamicScale() if use_dynamic_scale else None
         )
         if existing_best_state is not None:
             best_state = state.replace(
{flaxdiff-0.1.9.dist-info → flaxdiff-0.1.10.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.9
+Version: 0.1.10
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
{flaxdiff-0.1.9.dist-info → flaxdiff-0.1.10.dist-info}/RECORD RENAMED
@@ -2,9 +2,9 @@ flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
 flaxdiff/models/attention.py,sha256=YyVI3dTAMB8cS8VWHgtIigr2YY-MYfFTlaNDfjNJOCk,12596
-flaxdiff/models/common.py,sha256=nh32GIfgT_vVab4DEFiRAns5WGKbv6L5xNhzzfKKyBs,10590
+flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
 flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
-flaxdiff/models/simple_unet.py,sha256=_elSWNaB3EG-DwnrdIPVPF4OkU0xaa2IJk6OVITOwWM,9691
+flaxdiff/models/simple_unet.py,sha256=H67Pfy8BqKHvhdw_K3lBiFdruNQFBMElw8SDZdvg9Ec,10084
 flaxdiff/models/simple_vit.py,sha256=xD23i1b7WEvoH4tUMsLyCe9ebDcv-PpaV0Nso38Jlb8,3887
 flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
 flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
@@ -30,9 +30,9 @@ flaxdiff/schedulers/linear.py,sha256=6003F5ISq1Wc0h6UAzY95MJgsDIKGMhBzbiVALpea0k
 flaxdiff/schedulers/sqrt.py,sha256=1F84ZgQPuoNMhe6yxGTR2G0h7dPOZtm4UDQOakbSsEU,445
 flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,128
 flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
-flaxdiff/trainer/diffusion_trainer.py,sha256=z-ERdPt8mB6drXXlLjbGpbPreDIQlGmJFPRJhaoEZ1M,9242
-flaxdiff/trainer/simple_trainer.py,sha256=Dv2F7e2PQS_2b972iRr66odCcPPdJ9cZAD5t9LguOiw,18110
-flaxdiff-0.1.9.dist-info/METADATA,sha256=HhZlM5rBZrOSpNhS8KpeBCoXSmbsHy8ZAKY7gj10P0c,22082
-flaxdiff-0.1.9.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
-flaxdiff-0.1.9.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
-flaxdiff-0.1.9.dist-info/RECORD,,
+flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
+flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
+flaxdiff-0.1.10.dist-info/METADATA,sha256=q9O56jlhtuznnbmlHeKa9-gLFtWXge0bwBU6g9_P8Jk,22083
+flaxdiff-0.1.10.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+flaxdiff-0.1.10.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.10.dist-info/RECORD,,