flaxdiff 0.1.9__tar.gz → 0.1.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/PKG-INFO +1 -1
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/models/common.py +5 -3
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/models/simple_unet.py +13 -8
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/trainer/diffusion_trainer.py +4 -2
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/trainer/simple_trainer.py +6 -3
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff.egg-info/PKG-INFO +1 -1
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/setup.py +1 -1
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/README.md +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/__init__.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/models/__init__.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/models/attention.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/models/autoencoder/__init__.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/models/autoencoder/diffusers.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/models/favor_fastattn.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/models/simple_vit.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/predictors/__init__.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/samplers/__init__.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/samplers/common.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/samplers/ddim.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/samplers/ddpm.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/samplers/euler.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/samplers/heun_sampler.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/samplers/multistep_dpm.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/samplers/rk4_sampler.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/schedulers/__init__.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/schedulers/common.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/schedulers/continuous.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/schedulers/cosine.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/schedulers/discrete.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/schedulers/exp.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/schedulers/karras.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/schedulers/linear.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/schedulers/sqrt.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/trainer/__init__.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/utils.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff.egg-info/SOURCES.txt +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff.egg-info/requires.txt +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff.egg-info/top_level.txt +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.10}/setup.cfg +0 -0
{flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/models/common.py
@@ -267,15 +267,17 @@ class ResidualBlock(nn.Module):
     kernel_init:Callable=kernel_init(1.0)
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
+    named_norms:bool=False
 
     def setup(self):
         if self.norm_groups > 0:
             norm = partial(nn.GroupNorm, self.norm_groups)
+            self.norm1 = norm(name="GroupNorm_0") if self.named_norms else norm()
+            self.norm2 = norm(name="GroupNorm_1") if self.named_norms else norm()
         else:
             norm = partial(nn.RMSNorm, 1e-5)
-
-        self.norm1 = norm()
-        self.norm2 = norm()
+            self.norm1 = norm()
+            self.norm2 = norm()
 
     @nn.compact
     def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None):
{flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/models/simple_unet.py
@@ -19,15 +19,15 @@ class Unet(nn.Module):
     norm_groups:int=8
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
+    named_norms: bool = False # This is for backward compatibility reasons; older checkpoints have named norms
 
     def setup(self):
         if self.norm_groups > 0:
             norm = partial(nn.GroupNorm, self.norm_groups)
+            self.conv_out_norm = norm(name="GroupNorm_0") if self.named_norms else norm()
         else:
             norm = partial(nn.RMSNorm, 1e-5)
-
-        # self.last_up_norm = norm()
-        self.conv_out_norm = norm()
+            self.conv_out_norm = norm()
 
     @nn.compact
     def __call__(self, x, temb, textcontext):
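A note on the `named_norms` flag introduced above: in Flax, a submodule assigned in `setup()` is named after the attribute it is bound to unless an explicit `name=` is passed, and that name becomes part of the checkpoint parameter path. A minimal, self-contained sketch, not taken from the package, illustrating why the explicit `GroupNorm_0` name keeps older checkpoints loadable (the `Block` module below is purely illustrative):

```python
# Illustrative only: how an explicit submodule name changes parameter paths.
import jax
import jax.numpy as jnp
import flax.linen as nn

class Block(nn.Module):
    named_norms: bool = False

    def setup(self):
        # With named_norms=True the params live under "GroupNorm_0" (as in
        # older checkpoints); otherwise they live under the attribute name.
        self.conv_out_norm = (
            nn.GroupNorm(4, name="GroupNorm_0") if self.named_norms else nn.GroupNorm(4)
        )

    def __call__(self, x):
        return self.conv_out_norm(x)

x = jnp.ones((1, 8, 8, 8))
old_style = Block(named_norms=True).init(jax.random.PRNGKey(0), x)
new_style = Block(named_norms=False).init(jax.random.PRNGKey(0), x)
print(list(old_style["params"].keys()))  # ['GroupNorm_0']
print(list(new_style["params"].keys()))  # ['conv_out_norm']
```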
@@ -70,7 +70,8 @@ class Unet(nn.Module):
                     activation=self.activation,
                     norm_groups=self.norm_groups,
                     dtype=self.dtype,
-                    precision=self.precision
+                    precision=self.precision,
+                    named_norms=self.named_norms
                 )(x, temb)
                 if attention_config is not None and j == self.num_res_blocks - 1:   # Apply attention only on the last block
                     x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
@@ -108,7 +109,8 @@ class Unet(nn.Module):
                 activation=self.activation,
                 norm_groups=self.norm_groups,
                 dtype=self.dtype,
-                precision=self.precision
+                precision=self.precision,
+                named_norms=self.named_norms
             )(x, temb)
             if middle_attention is not None and j == self.num_middle_res_blocks - 1:   # Apply attention only on the last block
                 x = TransformerBlock(heads=middle_attention['heads'], dtype=middle_attention.get('dtype', jnp.float32),
@@ -130,7 +132,8 @@ class Unet(nn.Module):
                 activation=self.activation,
                 norm_groups=self.norm_groups,
                 dtype=self.dtype,
-                precision=self.precision
+                precision=self.precision,
+                named_norms=self.named_norms
             )(x, temb)
 
         # Upscaling Blocks
@@ -151,7 +154,8 @@ class Unet(nn.Module):
                     activation=self.activation,
                     norm_groups=self.norm_groups,
                     dtype=self.dtype,
-                    precision=self.precision
+                    precision=self.precision,
+                    named_norms=self.named_norms
                 )(x, temb)
                 if attention_config is not None and j == self.num_res_blocks - 1:   # Apply attention only on the last block
                     x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
@@ -196,7 +200,8 @@ class Unet(nn.Module):
             activation=self.activation,
             norm_groups=self.norm_groups,
             dtype=self.dtype,
-            precision=self.precision
+            precision=self.precision,
+            named_norms=self.named_norms
         )(x, temb)
 
         x = self.conv_out_norm(x)
{flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/trainer/diffusion_trainer.py
@@ -16,6 +16,7 @@ from flaxdiff.utils import RandomMarkovState
 from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
 
 from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
+from flax.training.dynamic_scale import DynamicScale
 
 class TrainState(SimpleTrainState):
     rngs: jax.random.PRNGKey
@@ -83,7 +84,8 @@ class DiffusionTrainer(SimpleTrainer):
             new_state = existing_state
 
         if param_transforms is not None:
-            params = param_transforms(params)
+            new_state['params'] = param_transforms(new_state['params'])
+            new_state['ema_params'] = param_transforms(new_state['ema_params'])
 
         state = TrainState.create(
             apply_fn=model.apply,
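The hunk above applies `param_transforms` to both the restored `params` and `ema_params` trees when resuming from an existing state. A minimal sketch of a callable that would fit this hook, assuming any pytree-to-pytree function is acceptable; the bfloat16 cast is illustrative, not something the library prescribes:

```python
# Illustrative param_transforms callable: cast every parameter leaf to bfloat16.
import jax
import jax.numpy as jnp

def cast_params_to_bfloat16(params):
    # Works on any pytree of arrays; the trainer stores whatever tree is returned.
    return jax.tree_util.tree_map(lambda p: p.astype(jnp.bfloat16), params)
```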
@@ -92,7 +94,7 @@ class DiffusionTrainer(SimpleTrainer):
             tx=optimizer,
             rngs=rngs,
             metrics=Metrics.empty(),
-            dynamic_scale =
+            dynamic_scale = DynamicScale() if use_dynamic_scale else None
         )
 
         if existing_best_state is not None:
{flaxdiff-0.1.9 → flaxdiff-0.1.10}/flaxdiff/trainer/simple_trainer.py
@@ -22,7 +22,7 @@ from jax.experimental.shard_map import shard_map
 from orbax.checkpoint.utils import fully_replicated_host_local_array_to_global_array
 from termcolor import colored
 from typing import Dict, Callable, Sequence, Any, Union, Tuple
-
+from flax.training.dynamic_scale import DynamicScale
 from flaxdiff.utils import RandomMarkovState
 
 PROCESS_COLOR_MAP = {
@@ -68,7 +68,7 @@ class Metrics(metrics.Collection):
 # Define the TrainState
 class SimpleTrainState(train_state.TrainState):
     metrics: Metrics
-    dynamic_scale:
+    dynamic_scale: DynamicScale
 
 class SimpleTrainer:
     state: SimpleTrainState
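`flax.training.dynamic_scale.DynamicScale`, now the declared type of the field above, implements dynamic loss scaling for mixed-precision training: its `value_and_grad` wrapper scales the loss, unscales the gradients, and reports whether they were all finite. A minimal sketch of the usual update pattern it enables, assuming a `TrainState`-style object with `params`, `opt_state`, `apply_gradients`, and the new `dynamic_scale` field; this is illustrative, not the trainer's actual train step:

```python
# Illustrative loss-scaled gradient step; `state` and `loss_fn` are assumptions,
# not names from the package.
import jax
import jax.numpy as jnp

def grad_step(state, loss_fn):
    if state.dynamic_scale is not None:
        # DynamicScale.value_and_grad returns (new_scale, all_finite, loss, grads).
        dynamic_scale, is_finite, loss, grads = state.dynamic_scale.value_and_grad(loss_fn)(state.params)
        new_state = state.apply_gradients(grads=grads)
        # On overflow (non-finite grads), keep the previous params and optimizer state.
        new_state = new_state.replace(
            params=jax.tree_util.tree_map(
                lambda new, old: jnp.where(is_finite, new, old),
                new_state.params, state.params),
            opt_state=jax.tree_util.tree_map(
                lambda new, old: jnp.where(is_finite, new, old),
                new_state.opt_state, state.opt_state),
            dynamic_scale=dynamic_scale,
        )
    else:
        loss, grads = jax.value_and_grad(loss_fn)(state.params)
        new_state = state.apply_gradients(grads=grads)
    return new_state, loss
```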
@@ -177,13 +177,16 @@ class SimpleTrainer:
             params = model.init(subkey, **input_vars)
         else:
             params = existing_state['params']
+
+        if param_transforms is not None:
+            params = param_transforms(params)
 
         state = SimpleTrainState.create(
             apply_fn=model.apply,
             params=params,
             tx=optimizer,
             metrics=Metrics.empty(),
-            dynamic_scale =
+            dynamic_scale = DynamicScale() if use_dynamic_scale else None
         )
         if existing_best_state is not None:
             best_state = state.replace(
{flaxdiff-0.1.9 → flaxdiff-0.1.10}/setup.py
@@ -11,7 +11,7 @@ required_packages=[
 setup(
     name='flaxdiff',
     packages=find_packages(),
-    version='0.1.9',
+    version='0.1.10',
     description='A versatile and easy to understand Diffusion library',
     long_description=open('README.md').read(),
     long_description_content_type='text/markdown',