flaxdiff 0.1.25__tar.gz → 0.1.27__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/PKG-INFO +1 -1
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/models/simple_vit.py +16 -6
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff.egg-info/PKG-INFO +1 -1
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/setup.py +1 -1
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/README.md +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/__init__.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/data/__init__.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/data/online_loader.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/models/__init__.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/models/attention.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/models/autoencoder/__init__.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/models/autoencoder/diffusers.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/models/common.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/models/favor_fastattn.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/models/simple_unet.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/predictors/__init__.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/samplers/__init__.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/samplers/common.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/samplers/ddim.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/samplers/ddpm.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/samplers/euler.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/samplers/heun_sampler.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/samplers/multistep_dpm.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/samplers/rk4_sampler.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/schedulers/__init__.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/schedulers/common.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/schedulers/continuous.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/schedulers/cosine.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/schedulers/discrete.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/schedulers/exp.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/schedulers/karras.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/schedulers/linear.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/schedulers/sqrt.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/trainer/__init__.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/trainer/diffusion_trainer.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/trainer/simple_trainer.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff/utils.py +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff.egg-info/SOURCES.txt +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff.egg-info/requires.txt +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/flaxdiff.egg-info/top_level.txt +0 -0
- {flaxdiff-0.1.25 → flaxdiff-0.1.27}/setup.cfg +0 -0
@@ -23,6 +23,7 @@ class PatchEmbedding(nn.Module):
|
|
23
23
|
embedding_dim: int
|
24
24
|
dtype: Any = jnp.float32
|
25
25
|
precision: Any = jax.lax.Precision.HIGH
|
26
|
+
kernel_init: Callable = partial(kernel_init, 1.0)
|
26
27
|
|
27
28
|
@nn.compact
|
28
29
|
def __call__(self, x):
|
@@ -33,6 +34,7 @@ class PatchEmbedding(nn.Module):
|
|
33
34
|
kernel_size=(self.patch_size, self.patch_size),
|
34
35
|
strides=(self.patch_size, self.patch_size),
|
35
36
|
dtype=self.dtype,
|
37
|
+
kernel_init=self.kernel_init(),
|
36
38
|
precision=self.precision)(x)
|
37
39
|
x = jnp.reshape(x, (batch, -1, self.embedding_dim))
|
38
40
|
return x
|
@@ -81,7 +83,7 @@ class UViT(nn.Module):
|
|
81
83
|
|
82
84
|
# Patch embedding
|
83
85
|
x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features,
|
84
|
-
dtype=self.dtype, precision=self.precision)(x)
|
86
|
+
dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init)(x)
|
85
87
|
num_patches = x.shape[1]
|
86
88
|
|
87
89
|
context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
|
@@ -119,14 +121,14 @@ class UViT(nn.Module):
|
|
119
121
|
|
120
122
|
# # Out blocks
|
121
123
|
for i in range(self.num_layers // 2):
|
122
|
-
|
123
|
-
|
124
|
-
dtype=self.dtype, precision=self.precision)(
|
124
|
+
x = jnp.concatenate([x, skips.pop()], axis=-1)
|
125
|
+
x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
|
126
|
+
dtype=self.dtype, precision=self.precision)(x)
|
125
127
|
x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
|
126
128
|
dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
|
127
129
|
use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
|
128
130
|
only_pure_attention=False,
|
129
|
-
kernel_init=self.kernel_init())(
|
131
|
+
kernel_init=self.kernel_init())(x)
|
130
132
|
|
131
133
|
# print(f'Shape of x after transformer blocks: {x.shape}')
|
132
134
|
x = self.norm()(x)
|
@@ -139,6 +141,14 @@ class UViT(nn.Module):
|
|
139
141
|
x = x[:, 1 + num_text_tokens:, :]
|
140
142
|
x = unpatchify(x, channels=self.output_channels)
|
141
143
|
# print(f'Shape of x after final dense layer: {x.shape}')
|
142
|
-
x = nn.
|
144
|
+
x = nn.Conv(
|
145
|
+
features=self.output_channels,
|
146
|
+
kernel_size=(3, 3),
|
147
|
+
strides=(1, 1),
|
148
|
+
padding='SAME',
|
149
|
+
dtype=self.dtype,
|
150
|
+
precision=self.precision,
|
151
|
+
kernel_init=kernel_init(0.0),
|
152
|
+
)(x)
|
143
153
|
|
144
154
|
return x
|
@@ -11,7 +11,7 @@ required_packages=[
|
|
11
11
|
setup(
|
12
12
|
name='flaxdiff',
|
13
13
|
packages=find_packages(),
|
14
|
-
version='0.1.
|
14
|
+
version='0.1.27',
|
15
15
|
description='A versatile and easy to understand Diffusion library',
|
16
16
|
long_description=open('README.md').read(),
|
17
17
|
long_description_content_type='text/markdown',
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|