flaxdiff 0.1.38__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/__init__.py +5 -1
- flaxdiff/data/benchmark_decord.py +443 -0
- flaxdiff/data/dataloaders.py +608 -0
- flaxdiff/data/dataset_map.py +61 -6
- flaxdiff/data/online_loader.py +779 -150
- flaxdiff/data/sources/audio_utils.py +142 -0
- flaxdiff/data/sources/av_example.py +125 -0
- flaxdiff/data/sources/av_utils.py +590 -0
- flaxdiff/data/sources/base.py +129 -0
- flaxdiff/data/sources/images.py +309 -0
- flaxdiff/data/sources/utils.py +158 -0
- flaxdiff/data/sources/videos.py +250 -0
- flaxdiff/data/sources/voxceleb2.py +412 -0
- flaxdiff/inference/__init__.py +0 -0
- flaxdiff/inference/pipeline.py +260 -0
- flaxdiff/inference/utils.py +320 -0
- flaxdiff/inputs/__init__.py +173 -0
- flaxdiff/inputs/encoders.py +98 -0
- flaxdiff/models/__init__.py +2 -1
- flaxdiff/models/attention.py +22 -16
- flaxdiff/models/autoencoder/autoencoder.py +141 -9
- flaxdiff/models/autoencoder/diffusers.py +88 -25
- flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
- flaxdiff/models/common.py +8 -18
- flaxdiff/models/simple_unet.py +6 -17
- flaxdiff/models/simple_vit.py +9 -13
- flaxdiff/models/unet_3d.py +446 -0
- flaxdiff/models/unet_3d_blocks.py +505 -0
- flaxdiff/samplers/common.py +358 -96
- flaxdiff/samplers/ddim.py +44 -5
- flaxdiff/schedulers/karras.py +20 -12
- flaxdiff/trainer/__init__.py +2 -1
- flaxdiff/trainer/autoencoder_trainer.py +1 -2
- flaxdiff/trainer/diffusion_trainer.py +35 -29
- flaxdiff/trainer/general_diffusion_trainer.py +583 -0
- flaxdiff/trainer/simple_trainer.py +51 -16
- flaxdiff/utils.py +128 -57
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
- flaxdiff-0.2.0.dist-info/RECORD +64 -0
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
- flaxdiff/data/datasets.py +0 -169
- flaxdiff/data/sources/gcs.py +0 -81
- flaxdiff/data/sources/tfds.py +0 -79
- flaxdiff/trainer/video_diffusion_trainer.py +0 -62
- flaxdiff-0.1.38.dist-info/RECORD +0 -50
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
@@ -22,7 +22,9 @@ class StableDiffusionVAE(AutoEncoder):
|
|
22
22
|
dtype=dtype,
|
23
23
|
)
|
24
24
|
|
25
|
-
|
25
|
+
self.modelname = modelname
|
26
|
+
self.revision = revision
|
27
|
+
self.dtype = dtype
|
26
28
|
|
27
29
|
enc = FlaxEncoder(
|
28
30
|
in_channels=vae.config.in_channels,
|
@@ -63,29 +65,90 @@ class StableDiffusionVAE(AutoEncoder):
|
|
63
65
|
dtype=vae.dtype,
|
64
66
|
)
|
65
67
|
|
66
|
-
|
67
|
-
|
68
|
-
self.post_quant_conv = post_quant_conv
|
69
|
-
self.quant_conv = quant_conv
|
70
|
-
self.params = params
|
71
|
-
self.scaling_factor = vae.scaling_factor
|
68
|
+
scaling_factor = vae.scaling_factor
|
69
|
+
print(f"Scaling factor: {scaling_factor}")
|
72
70
|
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
71
|
+
def encode_single_frame(images, rngkey: jax.random.PRNGKey = None):
|
72
|
+
latents = enc.apply({"params": params['encoder']}, images, deterministic=True)
|
73
|
+
latents = quant_conv.apply({"params": params['quant_conv']}, latents)
|
74
|
+
if rngkey is not None:
|
75
|
+
mean, log_std = jnp.split(latents, 2, axis=-1)
|
76
|
+
log_std = jnp.clip(log_std, -30, 20)
|
77
|
+
std = jnp.exp(0.5 * log_std)
|
78
|
+
latents = mean + std * jax.random.normal(rngkey, mean.shape, dtype=mean.dtype)
|
79
|
+
else:
|
80
|
+
latents, _ = jnp.split(latents, 2, axis=-1)
|
81
|
+
latents *= scaling_factor
|
82
|
+
return latents
|
83
|
+
|
84
|
+
def decode_single_frame(latents):
|
85
|
+
latents = (1.0 / scaling_factor) * latents
|
86
|
+
latents = post_quant_conv.apply({"params": params['post_quant_conv']}, latents)
|
87
|
+
return dec.apply({"params": params['decoder']}, latents)
|
88
|
+
|
89
|
+
self.encode_single_frame = jax.jit(encode_single_frame)
|
90
|
+
self.decode_single_frame = jax.jit(decode_single_frame)
|
91
|
+
|
92
|
+
# Calculate downscale factor by passing a dummy input through the encoder
|
93
|
+
print("Calculating downscale factor...")
|
94
|
+
dummy_input = jnp.ones((1, 128, 128, 3), dtype=dtype)
|
95
|
+
dummy_latents = self.encode_single_frame(dummy_input)
|
96
|
+
_, h, w, c = dummy_latents.shape
|
97
|
+
_, H, W, C = dummy_input.shape
|
98
|
+
self.__downscale_factor__ = H // h
|
99
|
+
self.__latent_channels__ = c
|
100
|
+
print(f"Downscale factor: {self.__downscale_factor__}")
|
101
|
+
print(f"Latent channels: {self.__latent_channels__}")
|
102
|
+
|
103
|
+
def __encode__(self, images, key: jax.random.PRNGKey = None, **kwargs):
|
104
|
+
"""Encode a batch of images to latent representations.
|
105
|
+
|
106
|
+
Implements the abstract method from the parent class.
|
107
|
+
|
108
|
+
Args:
|
109
|
+
images: Image tensor of shape [B, H, W, C]
|
110
|
+
key: Optional random key for stochastic encoding
|
111
|
+
**kwargs: Additional arguments (unused)
|
112
|
+
|
113
|
+
Returns:
|
114
|
+
Latent representations of shape [B, h, w, c]
|
115
|
+
"""
|
116
|
+
return self.encode_single_frame(images, key)
|
117
|
+
|
118
|
+
def __decode__(self, latents, **kwargs):
|
119
|
+
"""Decode latent representations to images.
|
120
|
+
|
121
|
+
Implements the abstract method from the parent class.
|
122
|
+
|
123
|
+
Args:
|
124
|
+
latents: Latent tensor of shape [B, h, w, c]
|
125
|
+
**kwargs: Additional arguments (unused)
|
126
|
+
|
127
|
+
Returns:
|
128
|
+
Decoded images of shape [B, H, W, C]
|
129
|
+
"""
|
130
|
+
return self.decode_single_frame(latents)
|
131
|
+
|
132
|
+
@property
|
133
|
+
def downscale_factor(self) -> int:
|
134
|
+
"""Returns the downscale factor for the encoder."""
|
135
|
+
return self.__downscale_factor__
|
136
|
+
|
137
|
+
@property
|
138
|
+
def latent_channels(self) -> int:
|
139
|
+
"""Returns the number of channels in the latent space."""
|
140
|
+
return self.__latent_channels__
|
141
|
+
|
142
|
+
@property
|
143
|
+
def name(self) -> str:
|
144
|
+
"""Get the name of the autoencoder model."""
|
145
|
+
return "stable_diffusion"
|
87
146
|
|
88
|
-
def
|
89
|
-
|
90
|
-
|
91
|
-
|
147
|
+
def serialize(self):
|
148
|
+
"""Serialize the model to a dictionary format."""
|
149
|
+
return {
|
150
|
+
"modelname": self.modelname,
|
151
|
+
"revision": self.revision,
|
152
|
+
"dtype": str(self.dtype),
|
153
|
+
}
|
154
|
+
|
@@ -6,21 +6,53 @@ from flax.typing import Dtype, PrecisionLike
|
|
6
6
|
from .autoencoder import AutoEncoder
|
7
7
|
|
8
8
|
class SimpleAutoEncoder(AutoEncoder):
|
9
|
+
"""A simple autoencoder implementation using the abstract method pattern.
|
10
|
+
|
11
|
+
This implementation allows for handling both image and video data through
|
12
|
+
the parent class's handling of video reshaping.
|
13
|
+
"""
|
9
14
|
latent_channels: int
|
10
15
|
feature_depths: List[int]=[64, 128, 256, 512]
|
11
|
-
attention_configs:list=[{"heads":8}, {"heads":8}, {"heads":8}, {"heads":8}]
|
16
|
+
attention_configs:list=[{"heads":8}, {"heads":8}, {"heads":8}, {"heads":8}]
|
12
17
|
num_res_blocks: int=2
|
13
|
-
num_middle_res_blocks:int=1
|
18
|
+
num_middle_res_blocks:int=1
|
14
19
|
activation:Callable = jax.nn.swish
|
15
20
|
norm_groups:int=8
|
16
21
|
dtype: Optional[Dtype] = None
|
17
22
|
precision: PrecisionLike = None
|
18
23
|
|
19
|
-
|
24
|
+
def __encode__(self, x: jnp.ndarray, **kwargs):
|
25
|
+
"""Encode a batch of images to latent representations.
|
26
|
+
|
27
|
+
Implements the abstract method from the parent class.
|
20
28
|
|
29
|
+
Args:
|
30
|
+
x: Image tensor of shape [B, H, W, C]
|
31
|
+
**kwargs: Additional arguments
|
32
|
+
|
33
|
+
Returns:
|
34
|
+
Latent representations of shape [B, h, w, c]
|
35
|
+
"""
|
36
|
+
# TODO: Implement the actual encoding logic for single frames
|
37
|
+
# This is just a placeholder implementation
|
38
|
+
B, H, W, C = x.shape
|
39
|
+
h, w = H // 8, W // 8 # Example downsampling factor
|
40
|
+
return jnp.zeros((B, h, w, self.latent_channels))
|
21
41
|
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
42
|
+
def __decode__(self, z: jnp.ndarray, **kwargs):
|
43
|
+
"""Decode latent representations to images.
|
44
|
+
|
45
|
+
Implements the abstract method from the parent class.
|
46
|
+
|
47
|
+
Args:
|
48
|
+
z: Latent tensor of shape [B, h, w, c]
|
49
|
+
**kwargs: Additional arguments
|
50
|
+
|
51
|
+
Returns:
|
52
|
+
Decoded images of shape [B, H, W, C]
|
53
|
+
"""
|
54
|
+
# TODO: Implement the actual decoding logic for single frames
|
55
|
+
# This is just a placeholder implementation
|
56
|
+
B, h, w, c = z.shape
|
57
|
+
H, W = h * 8, w * 8 # Example upsampling factor
|
58
|
+
return jnp.zeros((B, H, W, 3))
|
flaxdiff/models/common.py
CHANGED
@@ -108,13 +108,16 @@ class FourierEmbedding(nn.Module):
|
|
108
108
|
class TimeProjection(nn.Module):
|
109
109
|
features:int
|
110
110
|
activation:Callable=jax.nn.gelu
|
111
|
-
kernel_init:Callable=kernel_init(1.0)
|
112
111
|
|
113
112
|
@nn.compact
|
114
113
|
def __call__(self, x):
|
115
|
-
x = nn.DenseGeneral(
|
114
|
+
x = nn.DenseGeneral(
|
115
|
+
self.features,
|
116
|
+
)(x)
|
116
117
|
x = self.activation(x)
|
117
|
-
x = nn.DenseGeneral(
|
118
|
+
x = nn.DenseGeneral(
|
119
|
+
self.features,
|
120
|
+
)(x)
|
118
121
|
x = self.activation(x)
|
119
122
|
return x
|
120
123
|
|
@@ -123,7 +126,6 @@ class SeparableConv(nn.Module):
|
|
123
126
|
kernel_size:tuple=(3, 3)
|
124
127
|
strides:tuple=(1, 1)
|
125
128
|
use_bias:bool=False
|
126
|
-
kernel_init:Callable=kernel_init(1.0)
|
127
129
|
padding:str="SAME"
|
128
130
|
dtype: Optional[Dtype] = None
|
129
131
|
precision: PrecisionLike = None
|
@@ -133,7 +135,7 @@ class SeparableConv(nn.Module):
|
|
133
135
|
in_features = x.shape[-1]
|
134
136
|
depthwise = nn.Conv(
|
135
137
|
features=in_features, kernel_size=self.kernel_size,
|
136
|
-
strides=self.strides,
|
138
|
+
strides=self.strides,
|
137
139
|
feature_group_count=in_features, use_bias=self.use_bias,
|
138
140
|
padding=self.padding,
|
139
141
|
dtype=self.dtype,
|
@@ -141,7 +143,7 @@ class SeparableConv(nn.Module):
|
|
141
143
|
)(x)
|
142
144
|
pointwise = nn.Conv(
|
143
145
|
features=self.features, kernel_size=(1, 1),
|
144
|
-
strides=(1, 1),
|
146
|
+
strides=(1, 1),
|
145
147
|
use_bias=self.use_bias,
|
146
148
|
dtype=self.dtype,
|
147
149
|
precision=self.precision
|
@@ -153,7 +155,6 @@ class ConvLayer(nn.Module):
|
|
153
155
|
features:int
|
154
156
|
kernel_size:tuple=(3, 3)
|
155
157
|
strides:tuple=(1, 1)
|
156
|
-
kernel_init:Callable=kernel_init(1.0)
|
157
158
|
dtype: Optional[Dtype] = None
|
158
159
|
precision: PrecisionLike = None
|
159
160
|
|
@@ -164,7 +165,6 @@ class ConvLayer(nn.Module):
|
|
164
165
|
features=self.features,
|
165
166
|
kernel_size=self.kernel_size,
|
166
167
|
strides=self.strides,
|
167
|
-
kernel_init=self.kernel_init,
|
168
168
|
dtype=self.dtype,
|
169
169
|
precision=self.precision
|
170
170
|
)
|
@@ -183,7 +183,6 @@ class ConvLayer(nn.Module):
|
|
183
183
|
features=self.features,
|
184
184
|
kernel_size=self.kernel_size,
|
185
185
|
strides=self.strides,
|
186
|
-
kernel_init=self.kernel_init,
|
187
186
|
dtype=self.dtype,
|
188
187
|
precision=self.precision
|
189
188
|
)
|
@@ -192,7 +191,6 @@ class ConvLayer(nn.Module):
|
|
192
191
|
features=self.features,
|
193
192
|
kernel_size=self.kernel_size,
|
194
193
|
strides=self.strides,
|
195
|
-
kernel_init=self.kernel_init,
|
196
194
|
dtype=self.dtype,
|
197
195
|
precision=self.precision
|
198
196
|
)
|
@@ -206,7 +204,6 @@ class Upsample(nn.Module):
|
|
206
204
|
activation:Callable=jax.nn.swish
|
207
205
|
dtype: Optional[Dtype] = None
|
208
206
|
precision: PrecisionLike = None
|
209
|
-
kernel_init:Callable=kernel_init(1.0)
|
210
207
|
|
211
208
|
@nn.compact
|
212
209
|
def __call__(self, x, residual=None):
|
@@ -221,7 +218,6 @@ class Upsample(nn.Module):
|
|
221
218
|
strides=(1, 1),
|
222
219
|
dtype=self.dtype,
|
223
220
|
precision=self.precision,
|
224
|
-
kernel_init=self.kernel_init
|
225
221
|
)(out)
|
226
222
|
if residual is not None:
|
227
223
|
out = jnp.concatenate([out, residual], axis=-1)
|
@@ -233,7 +229,6 @@ class Downsample(nn.Module):
|
|
233
229
|
activation:Callable=jax.nn.swish
|
234
230
|
dtype: Optional[Dtype] = None
|
235
231
|
precision: PrecisionLike = None
|
236
|
-
kernel_init:Callable=kernel_init(1.0)
|
237
232
|
|
238
233
|
@nn.compact
|
239
234
|
def __call__(self, x, residual=None):
|
@@ -244,7 +239,6 @@ class Downsample(nn.Module):
|
|
244
239
|
strides=(2, 2),
|
245
240
|
dtype=self.dtype,
|
246
241
|
precision=self.precision,
|
247
|
-
kernel_init=self.kernel_init
|
248
242
|
)(x)
|
249
243
|
if residual is not None:
|
250
244
|
if residual.shape[1] > out.shape[1]:
|
@@ -269,7 +263,6 @@ class ResidualBlock(nn.Module):
|
|
269
263
|
direction:str=None
|
270
264
|
res:int=2
|
271
265
|
norm_groups:int=8
|
272
|
-
kernel_init:Callable=kernel_init(1.0)
|
273
266
|
dtype: Optional[Dtype] = None
|
274
267
|
precision: PrecisionLike = None
|
275
268
|
named_norms:bool=False
|
@@ -296,7 +289,6 @@ class ResidualBlock(nn.Module):
|
|
296
289
|
features=self.features,
|
297
290
|
kernel_size=self.kernel_size,
|
298
291
|
strides=self.strides,
|
299
|
-
kernel_init=self.kernel_init,
|
300
292
|
name="conv1",
|
301
293
|
dtype=self.dtype,
|
302
294
|
precision=self.precision
|
@@ -321,7 +313,6 @@ class ResidualBlock(nn.Module):
|
|
321
313
|
features=self.features,
|
322
314
|
kernel_size=self.kernel_size,
|
323
315
|
strides=self.strides,
|
324
|
-
kernel_init=self.kernel_init,
|
325
316
|
name="conv2",
|
326
317
|
dtype=self.dtype,
|
327
318
|
precision=self.precision
|
@@ -333,7 +324,6 @@ class ResidualBlock(nn.Module):
|
|
333
324
|
features=self.features,
|
334
325
|
kernel_size=(1, 1),
|
335
326
|
strides=1,
|
336
|
-
kernel_init=self.kernel_init,
|
337
327
|
name="residual_conv",
|
338
328
|
dtype=self.dtype,
|
339
329
|
precision=self.precision
|
flaxdiff/models/simple_unet.py
CHANGED
@@ -10,17 +10,16 @@ from functools import partial
|
|
10
10
|
|
11
11
|
class Unet(nn.Module):
|
12
12
|
output_channels:int=3
|
13
|
-
emb_features:int=64*4
|
14
|
-
feature_depths:list=
|
15
|
-
attention_configs:list=
|
16
|
-
num_res_blocks:int=2
|
17
|
-
num_middle_res_blocks:int=1
|
13
|
+
emb_features:int=64*4
|
14
|
+
feature_depths:list=(64, 128, 256, 512)
|
15
|
+
attention_configs:list=({"heads":8}, {"heads":8}, {"heads":8}, {"heads":8})
|
16
|
+
num_res_blocks:int=2
|
17
|
+
num_middle_res_blocks:int=1
|
18
18
|
activation:Callable = jax.nn.swish
|
19
19
|
norm_groups:int=8
|
20
20
|
dtype: Optional[Dtype] = None
|
21
21
|
precision: PrecisionLike = None
|
22
22
|
named_norms: bool = False # This is for backward compatibility reasons; older checkpoints have named norms
|
23
|
-
kernel_init: Callable = partial(kernel_init, dtype=jnp.float32)
|
24
23
|
|
25
24
|
def setup(self):
|
26
25
|
if self.norm_groups > 0:
|
@@ -50,7 +49,6 @@ class Unet(nn.Module):
|
|
50
49
|
features=self.feature_depths[0],
|
51
50
|
kernel_size=(3, 3),
|
52
51
|
strides=(1, 1),
|
53
|
-
kernel_init=self.kernel_init(scale=1.0),
|
54
52
|
dtype=self.dtype,
|
55
53
|
precision=self.precision
|
56
54
|
)(x)
|
@@ -65,7 +63,6 @@ class Unet(nn.Module):
|
|
65
63
|
down_conv_type,
|
66
64
|
name=f"down_{i}_residual_{j}",
|
67
65
|
features=dim_in,
|
68
|
-
kernel_init=self.kernel_init(scale=1.0),
|
69
66
|
kernel_size=(3, 3),
|
70
67
|
strides=(1, 1),
|
71
68
|
activation=self.activation,
|
@@ -85,7 +82,6 @@ class Unet(nn.Module):
|
|
85
82
|
force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False),
|
86
83
|
norm_inputs=attention_config.get("norm_inputs", True),
|
87
84
|
explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
|
88
|
-
kernel_init=self.kernel_init(scale=1.0),
|
89
85
|
name=f"down_{i}_attention_{j}")(x, textcontext)
|
90
86
|
# print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
|
91
87
|
downs.append(x)
|
@@ -108,7 +104,6 @@ class Unet(nn.Module):
|
|
108
104
|
middle_conv_type,
|
109
105
|
name=f"middle_res1_{j}",
|
110
106
|
features=middle_dim_out,
|
111
|
-
kernel_init=self.kernel_init(scale=1.0),
|
112
107
|
kernel_size=(3, 3),
|
113
108
|
strides=(1, 1),
|
114
109
|
activation=self.activation,
|
@@ -129,13 +124,11 @@ class Unet(nn.Module):
|
|
129
124
|
force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
|
130
125
|
norm_inputs=middle_attention.get("norm_inputs", True),
|
131
126
|
explicitly_add_residual=middle_attention.get("explicitly_add_residual", True),
|
132
|
-
kernel_init=self.kernel_init(scale=1.0),
|
133
127
|
name=f"middle_attention_{j}")(x, textcontext)
|
134
128
|
x = ResidualBlock(
|
135
129
|
middle_conv_type,
|
136
130
|
name=f"middle_res2_{j}",
|
137
131
|
features=middle_dim_out,
|
138
|
-
kernel_init=self.kernel_init(scale=1.0),
|
139
132
|
kernel_size=(3, 3),
|
140
133
|
strides=(1, 1),
|
141
134
|
activation=self.activation,
|
@@ -157,7 +150,6 @@ class Unet(nn.Module):
|
|
157
150
|
up_conv_type,# if j == 0 else "separable",
|
158
151
|
name=f"up_{i}_residual_{j}",
|
159
152
|
features=dim_out,
|
160
|
-
kernel_init=self.kernel_init(scale=1.0),
|
161
153
|
kernel_size=kernel_size,
|
162
154
|
strides=(1, 1),
|
163
155
|
activation=self.activation,
|
@@ -177,7 +169,6 @@ class Unet(nn.Module):
|
|
177
169
|
force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
|
178
170
|
norm_inputs=attention_config.get("norm_inputs", True),
|
179
171
|
explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
|
180
|
-
kernel_init=self.kernel_init(scale=1.0),
|
181
172
|
name=f"up_{i}_attention_{j}")(x, textcontext)
|
182
173
|
# print("Upscaling ", i, x.shape)
|
183
174
|
if i != len(feature_depths) - 1:
|
@@ -196,7 +187,6 @@ class Unet(nn.Module):
|
|
196
187
|
features=self.feature_depths[0],
|
197
188
|
kernel_size=(3, 3),
|
198
189
|
strides=(1, 1),
|
199
|
-
kernel_init=self.kernel_init(scale=1.0),
|
200
190
|
dtype=self.dtype,
|
201
191
|
precision=self.precision
|
202
192
|
)(x)
|
@@ -207,7 +197,6 @@ class Unet(nn.Module):
|
|
207
197
|
conv_type,
|
208
198
|
name="final_residual",
|
209
199
|
features=self.feature_depths[0],
|
210
|
-
kernel_init=self.kernel_init(scale=1.0),
|
211
200
|
kernel_size=(3,3),
|
212
201
|
strides=(1, 1),
|
213
202
|
activation=self.activation,
|
@@ -226,7 +215,7 @@ class Unet(nn.Module):
|
|
226
215
|
kernel_size=(3, 3),
|
227
216
|
strides=(1, 1),
|
228
217
|
# activation=jax.nn.mish
|
229
|
-
kernel_init=self.kernel_init(scale=0.0),
|
218
|
+
# kernel_init=self.kernel_init(scale=0.0),
|
230
219
|
dtype=self.dtype,
|
231
220
|
precision=self.precision
|
232
221
|
)(x)
|
flaxdiff/models/simple_vit.py
CHANGED
@@ -23,7 +23,6 @@ class PatchEmbedding(nn.Module):
|
|
23
23
|
embedding_dim: int
|
24
24
|
dtype: Any = jnp.float32
|
25
25
|
precision: Any = jax.lax.Precision.HIGH
|
26
|
-
kernel_init: Callable = partial(kernel_init, 1.0)
|
27
26
|
|
28
27
|
@nn.compact
|
29
28
|
def __call__(self, x):
|
@@ -34,7 +33,6 @@ class PatchEmbedding(nn.Module):
|
|
34
33
|
kernel_size=(self.patch_size, self.patch_size),
|
35
34
|
strides=(self.patch_size, self.patch_size),
|
36
35
|
dtype=self.dtype,
|
37
|
-
kernel_init=self.kernel_init(),
|
38
36
|
precision=self.precision)(x)
|
39
37
|
x = jnp.reshape(x, (batch, -1, self.embedding_dim))
|
40
38
|
return x
|
@@ -53,7 +51,7 @@ class PositionalEncoding(nn.Module):
|
|
53
51
|
class UViT(nn.Module):
|
54
52
|
output_channels:int=3
|
55
53
|
patch_size: int = 16
|
56
|
-
emb_features:int=768
|
54
|
+
emb_features:int=768
|
57
55
|
num_layers: int = 12
|
58
56
|
num_heads: int = 12
|
59
57
|
dropout_rate: float = 0.1
|
@@ -67,7 +65,7 @@ class UViT(nn.Module):
|
|
67
65
|
norm_groups:int=8
|
68
66
|
dtype: Optional[Dtype] = None
|
69
67
|
precision: PrecisionLike = None
|
70
|
-
kernel_init: Callable = partial(kernel_init, scale=1.0)
|
68
|
+
# kernel_init: Callable = partial(kernel_init, scale=1.0)
|
71
69
|
add_residualblock_output: bool = False
|
72
70
|
norm_inputs: bool = False
|
73
71
|
explicitly_add_residual: bool = True
|
@@ -88,10 +86,10 @@ class UViT(nn.Module):
|
|
88
86
|
|
89
87
|
# Patch embedding
|
90
88
|
x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features,
|
91
|
-
dtype=self.dtype, precision=self.precision
|
89
|
+
dtype=self.dtype, precision=self.precision)(x)
|
92
90
|
num_patches = x.shape[1]
|
93
91
|
|
94
|
-
context_emb = nn.DenseGeneral(features=self.emb_features,
|
92
|
+
context_emb = nn.DenseGeneral(features=self.emb_features,
|
95
93
|
dtype=self.dtype, precision=self.precision)(textcontext)
|
96
94
|
num_text_tokens = textcontext.shape[1]
|
97
95
|
|
@@ -116,7 +114,7 @@ class UViT(nn.Module):
|
|
116
114
|
only_pure_attention=False,
|
117
115
|
norm_inputs=self.norm_inputs,
|
118
116
|
explicitly_add_residual=self.explicitly_add_residual,
|
119
|
-
|
117
|
+
)(x)
|
120
118
|
skips.append(x)
|
121
119
|
|
122
120
|
# Middle block
|
@@ -126,12 +124,12 @@ class UViT(nn.Module):
|
|
126
124
|
only_pure_attention=False,
|
127
125
|
norm_inputs=self.norm_inputs,
|
128
126
|
explicitly_add_residual=self.explicitly_add_residual,
|
129
|
-
|
127
|
+
)(x)
|
130
128
|
|
131
129
|
# # Out blocks
|
132
130
|
for i in range(self.num_layers // 2):
|
133
131
|
x = jnp.concatenate([x, skips.pop()], axis=-1)
|
134
|
-
x = nn.DenseGeneral(features=self.emb_features,
|
132
|
+
x = nn.DenseGeneral(features=self.emb_features,
|
135
133
|
dtype=self.dtype, precision=self.precision)(x)
|
136
134
|
x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
|
137
135
|
dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
|
@@ -139,13 +137,13 @@ class UViT(nn.Module):
|
|
139
137
|
only_pure_attention=False,
|
140
138
|
norm_inputs=self.norm_inputs,
|
141
139
|
explicitly_add_residual=self.explicitly_add_residual,
|
142
|
-
|
140
|
+
)(x)
|
143
141
|
|
144
142
|
# print(f'Shape of x after transformer blocks: {x.shape}')
|
145
143
|
x = self.norm()(x)
|
146
144
|
|
147
145
|
patch_dim = self.patch_size ** 2 * self.output_channels
|
148
|
-
x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision
|
146
|
+
x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision)(x)
|
149
147
|
x = x[:, 1 + num_text_tokens:, :]
|
150
148
|
x = unpatchify(x, channels=self.output_channels)
|
151
149
|
|
@@ -159,7 +157,6 @@ class UViT(nn.Module):
|
|
159
157
|
kernel_size=(3, 3),
|
160
158
|
strides=(1, 1),
|
161
159
|
# activation=jax.nn.mish
|
162
|
-
kernel_init=self.kernel_init(scale=0.0),
|
163
160
|
dtype=self.dtype,
|
164
161
|
precision=self.precision
|
165
162
|
)(x)
|
@@ -173,7 +170,6 @@ class UViT(nn.Module):
|
|
173
170
|
kernel_size=(3, 3),
|
174
171
|
strides=(1, 1),
|
175
172
|
# activation=jax.nn.mish
|
176
|
-
kernel_init=self.kernel_init(scale=0.0),
|
177
173
|
dtype=self.dtype,
|
178
174
|
precision=self.precision
|
179
175
|
)(x)
|