flaxdiff 0.1.24__tar.gz → 0.1.26__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/PKG-INFO +1 -1
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/models/common.py +18 -18
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/models/simple_vit.py +18 -8
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff.egg-info/PKG-INFO +1 -1
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/setup.py +1 -1
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/README.md +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/__init__.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/data/__init__.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/data/online_loader.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/models/__init__.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/models/attention.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/models/autoencoder/__init__.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/models/autoencoder/diffusers.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/models/favor_fastattn.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/models/simple_unet.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/predictors/__init__.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/samplers/__init__.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/samplers/common.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/samplers/ddim.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/samplers/ddpm.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/samplers/euler.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/samplers/heun_sampler.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/samplers/multistep_dpm.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/samplers/rk4_sampler.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/schedulers/__init__.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/schedulers/common.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/schedulers/continuous.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/schedulers/cosine.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/schedulers/discrete.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/schedulers/exp.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/schedulers/karras.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/schedulers/linear.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/schedulers/sqrt.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/trainer/__init__.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/trainer/diffusion_trainer.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/trainer/simple_trainer.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/utils.py +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff.egg-info/SOURCES.txt +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff.egg-info/requires.txt +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff.egg-info/top_level.txt +0 -0
- {flaxdiff-0.1.24 → flaxdiff-0.1.26}/setup.cfg +0 -0
{flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/models/common.py

@@ -108,13 +108,13 @@ class FourierEmbedding(nn.Module):
 class TimeProjection(nn.Module):
     features:int
     activation:Callable=jax.nn.gelu
-    kernel_init:Callable=
+    kernel_init:Callable=kernel_init(1.0)
 
     @nn.compact
     def __call__(self, x):
-        x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init
+        x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init)(x)
         x = self.activation(x)
-        x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init
+        x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init)(x)
         x = self.activation(x)
         return x
 
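For reference, the corrected TimeProjection amounts to the self-contained sketch below. The kernel_init helper shown is an assumption (a plain variance-scaling initializer factory); the diff only shows its call sites, not its definition, so flaxdiff's actual helper may differ.

# Minimal sketch of the corrected TimeProjection module.
# kernel_init here is an assumed helper, not flaxdiff's exact implementation.
from typing import Callable

import jax
import jax.numpy as jnp
import flax.linen as nn


def kernel_init(scale: float):
    # Returns an initializer callable, so kernel_init(1.0) is a valid field default.
    return nn.initializers.variance_scaling(scale, "fan_avg", "truncated_normal")


class TimeProjection(nn.Module):
    features: int
    activation: Callable = jax.nn.gelu
    kernel_init: Callable = kernel_init(1.0)

    @nn.compact
    def __call__(self, x):
        x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init)(x)
        x = self.activation(x)
        x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init)(x)
        x = self.activation(x)
        return x


# Usage: project a batch of Fourier time embeddings to `features` dimensions.
params = TimeProjection(features=256).init(jax.random.PRNGKey(0), jnp.ones((4, 128)))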
@@ -123,7 +123,7 @@ class SeparableConv(nn.Module):
     kernel_size:tuple=(3, 3)
     strides:tuple=(1, 1)
     use_bias:bool=False
-    kernel_init:Callable=
+    kernel_init:Callable=kernel_init(1.0)
     padding:str="SAME"
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None

@@ -133,7 +133,7 @@ class SeparableConv(nn.Module):
         in_features = x.shape[-1]
         depthwise = nn.Conv(
             features=in_features, kernel_size=self.kernel_size,
-            strides=self.strides, kernel_init=self.kernel_init
+            strides=self.strides, kernel_init=self.kernel_init,
             feature_group_count=in_features, use_bias=self.use_bias,
             padding=self.padding,
             dtype=self.dtype,

@@ -141,7 +141,7 @@ class SeparableConv(nn.Module):
         )(x)
         pointwise = nn.Conv(
             features=self.features, kernel_size=(1, 1),
-            strides=(1, 1), kernel_init=self.kernel_init
+            strides=(1, 1), kernel_init=self.kernel_init,
             use_bias=self.use_bias,
             dtype=self.dtype,
             precision=self.precision
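The SeparableConv hunks only touch the kernel_init default and a missing trailing comma; the module itself is a standard depthwise-separable convolution (a grouped depthwise conv followed by a 1x1 pointwise conv). A simplified sketch of that pattern, omitting the dtype/precision/padding fields of the real module:

# Depthwise-separable convolution pattern wrapped by SeparableConv (simplified).
import jax
import jax.numpy as jnp
import flax.linen as nn


class SeparableConvSketch(nn.Module):
    features: int
    kernel_size: tuple = (3, 3)
    strides: tuple = (1, 1)

    @nn.compact
    def __call__(self, x):
        in_features = x.shape[-1]
        # Depthwise: one filter per input channel (feature_group_count groups).
        x = nn.Conv(features=in_features, kernel_size=self.kernel_size,
                    strides=self.strides, feature_group_count=in_features,
                    use_bias=False)(x)
        # Pointwise: 1x1 conv mixes channels up to `features`.
        x = nn.Conv(features=self.features, kernel_size=(1, 1), use_bias=False)(x)
        return x


y, _ = SeparableConvSketch(features=64).init_with_output(
    jax.random.PRNGKey(0), jnp.ones((1, 32, 32, 16)))
print(y.shape)  # (1, 32, 32, 64)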
@@ -153,7 +153,7 @@ class ConvLayer(nn.Module):
     features:int
     kernel_size:tuple=(3, 3)
     strides:tuple=(1, 1)
-    kernel_init:Callable=
+    kernel_init:Callable=kernel_init(1.0)
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
 
@@ -164,7 +164,7 @@ class ConvLayer(nn.Module):
             features=self.features,
             kernel_size=self.kernel_size,
             strides=self.strides,
-            kernel_init=self.kernel_init
+            kernel_init=self.kernel_init,
             dtype=self.dtype,
             precision=self.precision
         )

@@ -183,7 +183,7 @@ class ConvLayer(nn.Module):
             features=self.features,
             kernel_size=self.kernel_size,
             strides=self.strides,
-            kernel_init=self.kernel_init
+            kernel_init=self.kernel_init,
             dtype=self.dtype,
             precision=self.precision
         )

@@ -192,7 +192,7 @@ class ConvLayer(nn.Module):
             features=self.features,
             kernel_size=self.kernel_size,
             strides=self.strides,
-            kernel_init=self.kernel_init
+            kernel_init=self.kernel_init,
             dtype=self.dtype,
             precision=self.precision
         )
@@ -206,7 +206,7 @@ class Upsample(nn.Module):
     activation:Callable=jax.nn.swish
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
-    kernel_init:Callable=
+    kernel_init:Callable=kernel_init(1.0)
 
     @nn.compact
     def __call__(self, x, residual=None):

@@ -221,7 +221,7 @@ class Upsample(nn.Module):
             strides=(1, 1),
             dtype=self.dtype,
             precision=self.precision,
-            kernel_init=self.kernel_init
+            kernel_init=self.kernel_init
         )(out)
         if residual is not None:
             out = jnp.concatenate([out, residual], axis=-1)

@@ -233,7 +233,7 @@ class Downsample(nn.Module):
     activation:Callable=jax.nn.swish
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
-    kernel_init:Callable=
+    kernel_init:Callable=kernel_init(1.0)
 
     @nn.compact
     def __call__(self, x, residual=None):

@@ -244,7 +244,7 @@ class Downsample(nn.Module):
             strides=(2, 2),
             dtype=self.dtype,
             precision=self.precision,
-            kernel_init=self.kernel_init
+            kernel_init=self.kernel_init
         )(x)
         if residual is not None:
             if residual.shape[1] > out.shape[1]:
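The Downsample change is again only the kernel_init default, but the surrounding context shows why the residual shape check exists: the (2, 2)-strided SAME conv halves the spatial dimensions, so a skip tensor coming from the previous resolution must be reduced before it can be concatenated along the channel axis. An illustrative check (the avg_pool reduction below is an assumption for demonstration, not flaxdiff's exact handling):

# Why Downsample compares residual.shape[1] with out.shape[1].
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((1, 64, 64, 32))
conv = nn.Conv(features=32, kernel_size=(3, 3), strides=(2, 2), padding="SAME")
out, _ = conv.init_with_output(jax.random.PRNGKey(0), x)
print(out.shape)  # (1, 32, 32, 32)

residual = jnp.ones((1, 64, 64, 16))
if residual.shape[1] > out.shape[1]:
    # Reduce the higher-resolution skip before channel concatenation (assumed strategy).
    residual = nn.avg_pool(residual, window_shape=(2, 2), strides=(2, 2))
merged = jnp.concatenate([out, residual], axis=-1)
print(merged.shape)  # (1, 32, 32, 48)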
@@ -269,7 +269,7 @@ class ResidualBlock(nn.Module):
     direction:str=None
     res:int=2
     norm_groups:int=8
-    kernel_init:Callable=
+    kernel_init:Callable=kernel_init(1.0)
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     named_norms:bool=False

@@ -296,7 +296,7 @@ class ResidualBlock(nn.Module):
             features=self.features,
             kernel_size=self.kernel_size,
             strides=self.strides,
-            kernel_init=self.kernel_init
+            kernel_init=self.kernel_init,
             name="conv1",
             dtype=self.dtype,
             precision=self.precision

@@ -321,7 +321,7 @@ class ResidualBlock(nn.Module):
             features=self.features,
             kernel_size=self.kernel_size,
             strides=self.strides,
-            kernel_init=self.kernel_init
+            kernel_init=self.kernel_init,
             name="conv2",
             dtype=self.dtype,
             precision=self.precision

@@ -333,7 +333,7 @@ class ResidualBlock(nn.Module):
             features=self.features,
             kernel_size=(1, 1),
             strides=1,
-            kernel_init=self.kernel_init
+            kernel_init=self.kernel_init,
             name="residual_conv",
             dtype=self.dtype,
             precision=self.precision
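The residual_conv touched in the last ResidualBlock hunk is the 1x1 shortcut projection that matches channel counts so the skip path can be added to the conv path. A stripped-down sketch of that pattern (the real block also applies group norms, activations and time-embedding conditioning that are not shown here):

# Residual block with a 1x1 shortcut projection, heavily simplified.
import jax
import jax.numpy as jnp
import flax.linen as nn


class ResidualBlockSketch(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        residual = x
        out = nn.Conv(self.features, (3, 3), padding="SAME", name="conv1")(x)
        out = jax.nn.swish(out)
        out = nn.Conv(self.features, (3, 3), padding="SAME", name="conv2")(out)
        if residual.shape[-1] != self.features:
            # Match channel counts so the skip can be added to the conv path.
            residual = nn.Conv(self.features, (1, 1), name="residual_conv")(residual)
        return out + residual


y, _ = ResidualBlockSketch(features=64).init_with_output(
    jax.random.PRNGKey(0), jnp.ones((1, 16, 16, 32)))
print(y.shape)  # (1, 16, 16, 64)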
{flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/models/simple_vit.py

@@ -23,6 +23,7 @@ class PatchEmbedding(nn.Module):
     embedding_dim: int
     dtype: Any = jnp.float32
     precision: Any = jax.lax.Precision.HIGH
+    kernel_init: Callable = partial(kernel_init, 1.0)
 
     @nn.compact
     def __call__(self, x):

@@ -33,6 +34,7 @@ class PatchEmbedding(nn.Module):
                     kernel_size=(self.patch_size, self.patch_size),
                     strides=(self.patch_size, self.patch_size),
                     dtype=self.dtype,
+                    kernel_init=self.kernel_init(),
                     precision=self.precision)(x)
         x = jnp.reshape(x, (batch, -1, self.embedding_dim))
         return x
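Unlike common.py, PatchEmbedding stores the initializer factory itself (partial(kernel_init, 1.0)) and invokes it as self.kernel_init() when the conv is built, so callers can override it with any zero-argument factory. A self-contained sketch of the patched module, with the same assumed kernel_init helper as above and the precision field omitted for brevity:

# PatchEmbedding with the new kernel_init field (sketch; kernel_init is assumed).
from functools import partial
from typing import Any, Callable

import jax
import jax.numpy as jnp
import flax.linen as nn


def kernel_init(scale: float):
    return nn.initializers.variance_scaling(scale, "fan_avg", "truncated_normal")


class PatchEmbedding(nn.Module):
    patch_size: int
    embedding_dim: int
    dtype: Any = jnp.float32
    kernel_init: Callable = partial(kernel_init, 1.0)

    @nn.compact
    def __call__(self, x):
        batch = x.shape[0]
        # Non-overlapping patches: kernel and stride both equal to patch_size.
        x = nn.Conv(self.embedding_dim,
                    kernel_size=(self.patch_size, self.patch_size),
                    strides=(self.patch_size, self.patch_size),
                    dtype=self.dtype,
                    kernel_init=self.kernel_init())(x)
        return jnp.reshape(x, (batch, -1, self.embedding_dim))


tokens, _ = PatchEmbedding(patch_size=16, embedding_dim=768).init_with_output(
    jax.random.PRNGKey(0), jnp.ones((1, 224, 224, 3)))
print(tokens.shape)  # (1, 196, 768)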
@@ -96,7 +98,7 @@ class UViT(nn.Module):
         # print(f'Shape of x after time embedding: {x.shape}')
 
         # Add positional encoding
-        x = PositionalEncoding(max_len=x.shape[1], embedding_dim=self.emb_features)(x)
+        x = PositionalEncoding(max_len=x.shape[1], embedding_dim=self.emb_features, kernel_init=self.kernel_init)(x)
 
         # print(f'Shape of x after positional encoding: {x.shape}')
 
@@ -113,20 +115,20 @@ class UViT(nn.Module):
         # Middle block
         x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                              dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                             use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.
+                             use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
                              only_pure_attention=False,
                              kernel_init=self.kernel_init())(x)
 
         # # Out blocks
         for i in range(self.num_layers // 2):
-
-
-            dtype=self.dtype, precision=self.precision)(
+            x = jnp.concatenate([x, skips.pop()], axis=-1)
+            x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
+                                dtype=self.dtype, precision=self.precision)(x)
             x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                                  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                                 use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.
+                                 use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
                                  only_pure_attention=False,
-                                 kernel_init=self.kernel_init())(
+                                 kernel_init=self.kernel_init())(x)
 
         # print(f'Shape of x after transformer blocks: {x.shape}')
         x = self.norm()(x)
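The restored out-block body pops a saved skip activation, concatenates it with the current tokens along the feature axis, and projects the result back to emb_features with a DenseGeneral before the next TransformerBlock, which is the long-skip pattern used by U-ViT style backbones. The fusion step in isolation (shapes here are illustrative):

# Skip fusion as in the out blocks: concatenate, then project back to emb_features.
import jax
import jax.numpy as jnp
import flax.linen as nn

emb_features = 768
x = jnp.ones((2, 64, emb_features))          # current token stream
skips = [jnp.ones((2, 64, emb_features))]    # activations saved by the in-blocks

merged = jnp.concatenate([x, skips.pop()], axis=-1)   # (2, 64, 2 * emb_features)
proj = nn.DenseGeneral(features=emb_features)
x, _ = proj.init_with_output(jax.random.PRNGKey(0), merged)
print(x.shape)  # (2, 64, 768)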
@@ -139,6 +141,14 @@ class UViT(nn.Module):
         x = x[:, 1 + num_text_tokens:, :]
         x = unpatchify(x, channels=self.output_channels)
         # print(f'Shape of x after final dense layer: {x.shape}')
-        x = nn.
+        x = nn.Conv(
+            features=self.output_channels,
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            padding='SAME',
+            dtype=self.dtype,
+            precision=self.precision,
+            kernel_init=kernel_init(0.0),
+        )(x)
 
         return x
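The rebuilt output head is a 3x3 convolution whose kernel uses kernel_init(0.0), i.e. a (near-)zero initial kernel, so a freshly initialized UViT predicts roughly zero everywhere, a common trick for stabilizing early diffusion training. The effect, demonstrated with an explicit zeros initializer standing in for the zero-scale helper (an assumption about what kernel_init(0.0) amounts to):

# Zero-initialized output conv: the model's prediction starts at zero.
import jax
import jax.numpy as jnp
import flax.linen as nn

out_conv = nn.Conv(features=3, kernel_size=(3, 3), strides=(1, 1),
                   padding="SAME", kernel_init=nn.initializers.zeros)
y, _ = out_conv.init_with_output(jax.random.PRNGKey(0),
                                 jnp.ones((1, 32, 32, 128)))
print(jnp.abs(y).max())  # 0.0 at initialization (bias is zero-initialized too)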
{flaxdiff-0.1.24 → flaxdiff-0.1.26}/setup.py

@@ -11,7 +11,7 @@ required_packages=[
 setup(
     name='flaxdiff',
     packages=find_packages(),
-    version='0.1.24',
+    version='0.1.26',
     description='A versatile and easy to understand Diffusion library',
     long_description=open('README.md').read(),
     long_description_content_type='text/markdown',
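The only packaging change is the version bump. Assuming the release is pulled from PyPI as usual, upgrading with pip install -U flaxdiff (or pinning flaxdiff==0.1.26) is enough to pick up the model fixes above.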