flaxdiff 0.1.7__tar.gz → 0.1.8__tar.gz
This diff shows the content of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/PKG-INFO +1 -1
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/models/attention.py +1 -60
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/models/common.py +4 -4
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/models/simple_unet.py +3 -3
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff.egg-info/PKG-INFO +1 -1
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/setup.py +1 -1
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/README.md +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/__init__.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/models/__init__.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/models/autoencoder/__init__.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/models/autoencoder/diffusers.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/models/favor_fastattn.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/models/simple_vit.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/predictors/__init__.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/samplers/__init__.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/samplers/common.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/samplers/ddim.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/samplers/ddpm.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/samplers/euler.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/samplers/heun_sampler.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/samplers/multistep_dpm.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/samplers/rk4_sampler.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/schedulers/__init__.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/schedulers/common.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/schedulers/continuous.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/schedulers/cosine.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/schedulers/discrete.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/schedulers/exp.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/schedulers/karras.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/schedulers/linear.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/schedulers/sqrt.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/trainer/__init__.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/trainer/diffusion_trainer.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/trainer/simple_trainer.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff/utils.py +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff.egg-info/SOURCES.txt +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff.egg-info/requires.txt +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/flaxdiff.egg-info/top_level.txt +0 -0
- {flaxdiff-0.1.7 → flaxdiff-0.1.8}/setup.cfg +0 -0
flaxdiff/models/attention.py

@@ -162,65 +162,6 @@ class NormalAttention(nn.Module):
         proj = proj.reshape(orig_x_shape)
         return proj

-class BasicTransformerBlock(nn.Module):
-    # Has self and cross attention
-    query_dim: int
-    heads: int = 4
-    dim_head: int = 64
-    dtype: Optional[Dtype] = None
-    precision: PrecisionLike = None
-    use_bias: bool = True
-    kernel_init: Callable = lambda : kernel_init(1.0)
-    use_flash_attention:bool = False
-    use_cross_only:bool = False
-
-    def setup(self):
-        if self.use_flash_attention:
-            attenBlock = EfficientAttention
-        else:
-            attenBlock = NormalAttention
-
-        self.attention1 = attenBlock(
-            query_dim=self.query_dim,
-            heads=self.heads,
-            dim_head=self.dim_head,
-            name=f'Attention1',
-            precision=self.precision,
-            use_bias=self.use_bias,
-            dtype=self.dtype,
-            kernel_init=self.kernel_init
-        )
-        self.attention2 = attenBlock(
-            query_dim=self.query_dim,
-            heads=self.heads,
-            dim_head=self.dim_head,
-            name=f'Attention2',
-            precision=self.precision,
-            use_bias=self.use_bias,
-            dtype=self.dtype,
-            kernel_init=self.kernel_init
-        )
-
-        self.ff = FlaxFeedForward(dim=self.query_dim)
-        self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-        self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-        self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-
-    @nn.compact
-    def __call__(self, hidden_states, context=None):
-        # self attention
-        if not self.use_cross_only:
-            print("Using self attention")
-            hidden_states = hidden_states + self.attention1(self.norm1(hidden_states))
-
-        # cross attention
-        hidden_states = hidden_states + self.attention2(self.norm2(hidden_states), context)
-
-        # feed forward
-        hidden_states = hidden_states + self.ff(self.norm3(hidden_states))
-
-        return hidden_states
-
 class FlaxGEGLU(nn.Module):
     r"""
     Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
@@ -330,7 +271,7 @@ class BasicTransformerBlock(nn.Module):
     @nn.compact
     def __call__(self, hidden_states, context=None):
         if self.only_pure_attention:
-            return self.attention2(
+            return self.attention2(hidden_states, context)

         # self attention
         if not self.use_cross_only:
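The change above rewrites a call that this view shows only in truncated form into an explicit return self.attention2(hidden_states, context): when only_pure_attention is set, the block hands both the hidden states and the conditioning context straight to its cross-attention layer and skips the rest of the block. Below is a minimal sketch of that control flow, not the library's own TransformerBlock; it stands in flax.linen's MultiHeadDotProductAttention for flaxdiff's attention module, and every name other than only_pure_attention and attention2 is hypothetical.

import jax
import jax.numpy as jnp
import flax.linen as nn

class PureAttentionSketch(nn.Module):
    # Hypothetical demo module, not flaxdiff's own block.
    heads: int = 4
    only_pure_attention: bool = True

    @nn.compact
    def __call__(self, hidden_states, context=None):
        context = hidden_states if context is None else context
        # Stand-in for the block's cross-attention layer (self.attention2 in the diff);
        # the same instance is reused below, so parameters are shared, which the real block avoids.
        attention2 = nn.MultiHeadDotProductAttention(num_heads=self.heads, name="Attention2")
        if self.only_pure_attention:
            # 0.1.8 behaviour: pass both tensors to cross attention and return its output.
            return attention2(hidden_states, context)
        # Otherwise the block keeps its residual self-attention / cross-attention path.
        hidden_states = hidden_states + attention2(hidden_states, hidden_states)
        hidden_states = hidden_states + attention2(hidden_states, context)
        return hidden_states

# Shape check only; weights are randomly initialised.
x = jnp.ones((1, 16, 64))
ctx = jnp.ones((1, 77, 64))
params = PureAttentionSketch().init(jax.random.PRNGKey(0), x, ctx)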
flaxdiff/models/common.py

@@ -270,8 +270,8 @@ class ResidualBlock(nn.Module):
     @nn.compact
     def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None):
         residual = x
-
-        out = nn.RMSNorm()(x)
+        out = nn.GroupNorm(self.norm_groups)(x)
+        # out = nn.RMSNorm()(x)
         out = self.activation(out)

         out = ConvLayer(
@@ -295,8 +295,8 @@ class ResidualBlock(nn.Module):
         # out = out * (1 + scale) + shift
         out = out + temb

-
-        out = nn.RMSNorm()(out)
+        out = nn.GroupNorm(self.norm_groups)(out)
+        # out = nn.RMSNorm()(out)
         out = self.activation(out)

         out = ConvLayer(
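Both ResidualBlock hunks make the same swap: group normalization, sized by the block's norm_groups field, now runs before each activation, and the previous RMSNorm call is kept as a comment. A minimal sketch of that GroupNorm, activation, convolution pattern follows, assuming nn.silu and a plain nn.Conv as stand-ins for the block's activation and ConvLayer.

import jax
import jax.numpy as jnp
import flax.linen as nn

class NormActConvSketch(nn.Module):
    # Hypothetical demo, not flaxdiff's ResidualBlock.
    features: int = 64
    norm_groups: int = 8  # field name mirrors the diff

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        out = nn.GroupNorm(num_groups=self.norm_groups)(x)     # replaces nn.RMSNorm() in 0.1.8
        out = nn.silu(out)                                      # stand-in for self.activation
        out = nn.Conv(self.features, kernel_size=(3, 3))(out)   # stand-in for ConvLayer
        return out + x                                          # residual connection

# GroupNorm requires the channel count (64 here) to be divisible by norm_groups.
x = jnp.ones((1, 32, 32, 64))
params = NormActConvSketch().init(jax.random.PRNGKey(0), x)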
flaxdiff/models/simple_unet.py

@@ -65,7 +65,7 @@ class Unet(nn.Module):
                 if attention_config is not None and j == self.num_res_blocks - 1: # Apply attention only on the last block
                     x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
                                          dim_head=dim_in // attention_config['heads'],
-                                         use_flash_attention=attention_config.get("flash_attention",
+                                         use_flash_attention=attention_config.get("flash_attention", False),
                                          use_projection=attention_config.get("use_projection", False),
                                          use_self_and_cross=attention_config.get("use_self_and_cross", True),
                                          precision=attention_config.get("precision", self.precision),
@@ -103,7 +103,7 @@ class Unet(nn.Module):
            if middle_attention is not None and j == self.num_middle_res_blocks - 1: # Apply attention only on the last block
                x = TransformerBlock(heads=middle_attention['heads'], dtype=middle_attention.get('dtype', jnp.float32),
                                     dim_head=middle_dim_out // middle_attention['heads'],
-                                    use_flash_attention=middle_attention.get("flash_attention",
+                                    use_flash_attention=middle_attention.get("flash_attention", False),
                                     use_linear_attention=False,
                                     use_projection=middle_attention.get("use_projection", False),
                                     use_self_and_cross=False,
@@ -146,7 +146,7 @@ class Unet(nn.Module):
                 if attention_config is not None and j == self.num_res_blocks - 1: # Apply attention only on the last block
                     x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
                                          dim_head=dim_out // attention_config['heads'],
-                                         use_flash_attention=attention_config.get("flash_attention",
+                                         use_flash_attention=attention_config.get("flash_attention", False),
                                          use_projection=attention_config.get("use_projection", False),
                                          use_self_and_cross=attention_config.get("use_self_and_cross", True),
                                          precision=attention_config.get("precision", self.precision),
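The three Unet hunks above all make the same change: the fallback passed to attention_config.get (and middle_attention.get) for the flash_attention key is now an explicit False; the removed lines are truncated in this view, so the earlier fallback is not visible. A small illustration of how such a lookup resolves, using a hypothetical attention_config:

# Hypothetical config; only keys that appear in the diff are used.
attention_config = {"heads": 8, "use_projection": True}

# With the 0.1.8 fallback, leaving out "flash_attention" resolves to False,
# so the non-flash attention path is taken unless it is requested explicitly.
use_flash_attention = attention_config.get("flash_attention", False)
use_projection = attention_config.get("use_projection", False)
use_self_and_cross = attention_config.get("use_self_and_cross", True)

print(use_flash_attention, use_projection, use_self_and_cross)  # False True True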
setup.py

@@ -11,7 +11,7 @@ required_packages=[
 setup(
     name='flaxdiff',
     packages=find_packages(),
-    version='0.1.7',
+    version='0.1.8',
     description='A versatile and easy to understand Diffusion library',
     long_description=open('README.md').read(),
     long_description_content_type='text/markdown',
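setup.py carries the new release number. After upgrading, the installed version can be confirmed from package metadata; this assumes flaxdiff was installed under that distribution name:

from importlib.metadata import version  # Python 3.8+

# Prints "0.1.8" once the new release is installed.
print(version("flaxdiff"))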