flaxdiff 0.1.38__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. flaxdiff/data/__init__.py +5 -1
  2. flaxdiff/data/benchmark_decord.py +443 -0
  3. flaxdiff/data/dataloaders.py +608 -0
  4. flaxdiff/data/dataset_map.py +61 -6
  5. flaxdiff/data/online_loader.py +779 -150
  6. flaxdiff/data/sources/audio_utils.py +142 -0
  7. flaxdiff/data/sources/av_example.py +125 -0
  8. flaxdiff/data/sources/av_utils.py +590 -0
  9. flaxdiff/data/sources/base.py +129 -0
  10. flaxdiff/data/sources/images.py +309 -0
  11. flaxdiff/data/sources/utils.py +158 -0
  12. flaxdiff/data/sources/videos.py +250 -0
  13. flaxdiff/data/sources/voxceleb2.py +412 -0
  14. flaxdiff/inference/__init__.py +0 -0
  15. flaxdiff/inference/pipeline.py +260 -0
  16. flaxdiff/inference/utils.py +320 -0
  17. flaxdiff/inputs/__init__.py +173 -0
  18. flaxdiff/inputs/encoders.py +98 -0
  19. flaxdiff/models/__init__.py +2 -1
  20. flaxdiff/models/attention.py +22 -16
  21. flaxdiff/models/autoencoder/autoencoder.py +141 -9
  22. flaxdiff/models/autoencoder/diffusers.py +88 -25
  23. flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
  24. flaxdiff/models/common.py +8 -18
  25. flaxdiff/models/simple_unet.py +6 -17
  26. flaxdiff/models/simple_vit.py +9 -13
  27. flaxdiff/models/unet_3d.py +446 -0
  28. flaxdiff/models/unet_3d_blocks.py +505 -0
  29. flaxdiff/samplers/common.py +358 -96
  30. flaxdiff/samplers/ddim.py +44 -5
  31. flaxdiff/schedulers/karras.py +20 -12
  32. flaxdiff/trainer/__init__.py +2 -1
  33. flaxdiff/trainer/autoencoder_trainer.py +1 -2
  34. flaxdiff/trainer/diffusion_trainer.py +35 -29
  35. flaxdiff/trainer/general_diffusion_trainer.py +583 -0
  36. flaxdiff/trainer/simple_trainer.py +51 -16
  37. flaxdiff/utils.py +128 -57
  38. {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
  39. flaxdiff-0.2.0.dist-info/RECORD +64 -0
  40. {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
  41. flaxdiff/data/datasets.py +0 -169
  42. flaxdiff/data/sources/gcs.py +0 -81
  43. flaxdiff/data/sources/tfds.py +0 -79
  44. flaxdiff/trainer/video_diffusion_trainer.py +0 -62
  45. flaxdiff-0.1.38.dist-info/RECORD +0 -50
  46. {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
flaxdiff/models/autoencoder/diffusers.py CHANGED
@@ -22,7 +22,9 @@ class StableDiffusionVAE(AutoEncoder):
  dtype=dtype,
  )
 
- # vae = pipeline.vae
+ self.modelname = modelname
+ self.revision = revision
+ self.dtype = dtype
 
  enc = FlaxEncoder(
  in_channels=vae.config.in_channels,
@@ -63,29 +65,90 @@ class StableDiffusionVAE(AutoEncoder):
  dtype=vae.dtype,
  )
 
- self.enc = enc
- self.dec = dec
- self.post_quant_conv = post_quant_conv
- self.quant_conv = quant_conv
- self.params = params
- self.scaling_factor = vae.scaling_factor
+ scaling_factor = vae.scaling_factor
+ print(f"Scaling factor: {scaling_factor}")
 
- def encode(self, images, rngkey: jax.random.PRNGKey = None):
- latents = self.enc.apply({"params": self.params["vae"]['encoder']}, images, deterministic=True)
- latents = self.quant_conv.apply({"params": self.params["vae"]['quant_conv']}, latents)
- if rngkey is not None:
- mean, log_std = jnp.split(latents, 2, axis=-1)
- log_std = jnp.clip(log_std, -30, 20)
- std = jnp.exp(0.5 * log_std)
- latents = mean + std * jax.random.normal(rngkey, mean.shape, dtype=mean.dtype)
- # print("Sampled")
- else:
- # return the mean
- latents, _ = jnp.split(latents, 2, axis=-1)
- latents *= self.scaling_factor
- return latents
+ def encode_single_frame(images, rngkey: jax.random.PRNGKey = None):
+ latents = enc.apply({"params": params['encoder']}, images, deterministic=True)
+ latents = quant_conv.apply({"params": params['quant_conv']}, latents)
+ if rngkey is not None:
+ mean, log_std = jnp.split(latents, 2, axis=-1)
+ log_std = jnp.clip(log_std, -30, 20)
+ std = jnp.exp(0.5 * log_std)
+ latents = mean + std * jax.random.normal(rngkey, mean.shape, dtype=mean.dtype)
+ else:
+ latents, _ = jnp.split(latents, 2, axis=-1)
+ latents *= scaling_factor
+ return latents
+
+ def decode_single_frame(latents):
+ latents = (1.0 / scaling_factor) * latents
+ latents = post_quant_conv.apply({"params": params['post_quant_conv']}, latents)
+ return dec.apply({"params": params['decoder']}, latents)
+
+ self.encode_single_frame = jax.jit(encode_single_frame)
+ self.decode_single_frame = jax.jit(decode_single_frame)
+
+ # Calculate downscale factor by passing a dummy input through the encoder
+ print("Calculating downscale factor...")
+ dummy_input = jnp.ones((1, 128, 128, 3), dtype=dtype)
+ dummy_latents = self.encode_single_frame(dummy_input)
+ _, h, w, c = dummy_latents.shape
+ _, H, W, C = dummy_input.shape
+ self.__downscale_factor__ = H // h
+ self.__latent_channels__ = c
+ print(f"Downscale factor: {self.__downscale_factor__}")
+ print(f"Latent channels: {self.__latent_channels__}")
+
+ def __encode__(self, images, key: jax.random.PRNGKey = None, **kwargs):
+ """Encode a batch of images to latent representations.
+
+ Implements the abstract method from the parent class.
+
+ Args:
+ images: Image tensor of shape [B, H, W, C]
+ key: Optional random key for stochastic encoding
+ **kwargs: Additional arguments (unused)
+
+ Returns:
+ Latent representations of shape [B, h, w, c]
+ """
+ return self.encode_single_frame(images, key)
+
+ def __decode__(self, latents, **kwargs):
+ """Decode latent representations to images.
+
+ Implements the abstract method from the parent class.
+
+ Args:
+ latents: Latent tensor of shape [B, h, w, c]
+ **kwargs: Additional arguments (unused)
+
+ Returns:
+ Decoded images of shape [B, H, W, C]
+ """
+ return self.decode_single_frame(latents)
+
+ @property
+ def downscale_factor(self) -> int:
+ """Returns the downscale factor for the encoder."""
+ return self.__downscale_factor__
+
+ @property
+ def latent_channels(self) -> int:
+ """Returns the number of channels in the latent space."""
+ return self.__latent_channels__
+
+ @property
+ def name(self) -> str:
+ """Get the name of the autoencoder model."""
+ return "stable_diffusion"
 
- def decode(self, latents):
- latents = (1.0 / self.scaling_factor) * latents
- latents = self.post_quant_conv.apply({"params": self.params["vae"]['post_quant_conv']}, latents)
- return self.dec.apply({"params": self.params["vae"]['decoder']}, latents)
+ def serialize(self):
+ """Serialize the model to a dictionary format."""
+ return {
+ "modelname": self.modelname,
+ "revision": self.revision,
+ "dtype": str(self.dtype),
+ }
+
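Note: the rewritten StableDiffusionVAE keeps jitted per-frame closures on the instance and exposes them through the __encode__/__decode__ hooks plus the new properties and serialize(). A minimal usage sketch of that surface (constructor defaults and input value ranges are assumptions, not taken from the diff):

    import jax
    import jax.numpy as jnp
    from flaxdiff.models.autoencoder.diffusers import StableDiffusionVAE

    vae = StableDiffusionVAE()                           # assumes default modelname/revision/dtype
    images = jnp.zeros((4, 256, 256, 3), dtype=jnp.float32)             # [B, H, W, C]
    latents = vae.encode_single_frame(images, jax.random.PRNGKey(0))    # stochastic sample from the posterior
    recon = vae.decode_single_frame(latents)                            # back to [B, H, W, C]
    print(vae.downscale_factor, vae.latent_channels)     # e.g. 8 and 4 for the SD VAE
    print(vae.serialize())                               # {"modelname": ..., "revision": ..., "dtype": ...}
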
flaxdiff/models/autoencoder/simple_autoenc.py CHANGED
@@ -6,21 +6,53 @@ from flax.typing import Dtype, PrecisionLike
  from .autoencoder import AutoEncoder
 
  class SimpleAutoEncoder(AutoEncoder):
+ """A simple autoencoder implementation using the abstract method pattern.
+
+ This implementation allows for handling both image and video data through
+ the parent class's handling of video reshaping.
+ """
  latent_channels: int
  feature_depths: List[int]=[64, 128, 256, 512]
- attention_configs:list=[{"heads":8}, {"heads":8}, {"heads":8}, {"heads":8}],
+ attention_configs:list=[{"heads":8}, {"heads":8}, {"heads":8}, {"heads":8}]
  num_res_blocks: int=2
- num_middle_res_blocks:int=1,
+ num_middle_res_blocks:int=1
  activation:Callable = jax.nn.swish
  norm_groups:int=8
  dtype: Optional[Dtype] = None
  precision: PrecisionLike = None
 
- # def encode(self, x: jnp.ndarray):
+ def __encode__(self, x: jnp.ndarray, **kwargs):
+ """Encode a batch of images to latent representations.
+
+ Implements the abstract method from the parent class.
 
+ Args:
+ x: Image tensor of shape [B, H, W, C]
+ **kwargs: Additional arguments
+
+ Returns:
+ Latent representations of shape [B, h, w, c]
+ """
+ # TODO: Implement the actual encoding logic for single frames
+ # This is just a placeholder implementation
+ B, H, W, C = x.shape
+ h, w = H // 8, W // 8 # Example downsampling factor
+ return jnp.zeros((B, h, w, self.latent_channels))
 
- @nn.compact
- def __call__(self, x: jnp.ndarray):
- latents = self.encode(x)
- reconstructions = self.decode(latents)
- return reconstructions
+ def __decode__(self, z: jnp.ndarray, **kwargs):
+ """Decode latent representations to images.
+
+ Implements the abstract method from the parent class.
+
+ Args:
+ z: Latent tensor of shape [B, h, w, c]
+ **kwargs: Additional arguments
+
+ Returns:
+ Decoded images of shape [B, H, W, C]
+ """
+ # TODO: Implement the actual decoding logic for single frames
+ # This is just a placeholder implementation
+ B, h, w, c = z.shape
+ H, W = h * 8, w * 8 # Example upsampling factor
+ return jnp.zeros((B, H, W, 3))
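Note: both autoencoders now implement only per-frame __encode__/__decode__; the docstrings indicate the AutoEncoder base class is responsible for folding video batches into frames before dispatching. A rough illustration of that pattern (the actual base-class method names and reshaping logic are not shown in this diff, so treat this as an assumption):

    class AutoEncoderBaseSketch:
        """Illustrative only: route images [B, H, W, C] and videos [B, T, H, W, C]
        through the per-frame __encode__ that subclasses implement."""
        def encode(self, x, **kwargs):
            if x.ndim == 5:                          # video: fold time into the batch axis
                B, T = x.shape[:2]
                frames = x.reshape(B * T, *x.shape[2:])
                z = self.__encode__(frames, **kwargs)
                return z.reshape(B, T, *z.shape[1:])
            return self.__encode__(x, **kwargs)      # plain image batch
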
flaxdiff/models/common.py CHANGED
@@ -108,13 +108,16 @@ class FourierEmbedding(nn.Module):
  class TimeProjection(nn.Module):
  features:int
  activation:Callable=jax.nn.gelu
- kernel_init:Callable=kernel_init(1.0)
 
  @nn.compact
  def __call__(self, x):
- x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init)(x)
+ x = nn.DenseGeneral(
+ self.features,
+ )(x)
  x = self.activation(x)
- x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init)(x)
+ x = nn.DenseGeneral(
+ self.features,
+ )(x)
  x = self.activation(x)
  return x
 
@@ -123,7 +126,6 @@ class SeparableConv(nn.Module):
  kernel_size:tuple=(3, 3)
  strides:tuple=(1, 1)
  use_bias:bool=False
- kernel_init:Callable=kernel_init(1.0)
  padding:str="SAME"
  dtype: Optional[Dtype] = None
  precision: PrecisionLike = None
@@ -133,7 +135,7 @@ class SeparableConv(nn.Module):
  in_features = x.shape[-1]
  depthwise = nn.Conv(
  features=in_features, kernel_size=self.kernel_size,
- strides=self.strides, kernel_init=self.kernel_init,
+ strides=self.strides,
  feature_group_count=in_features, use_bias=self.use_bias,
  padding=self.padding,
  dtype=self.dtype,
@@ -141,7 +143,7 @@ class SeparableConv(nn.Module):
  )(x)
  pointwise = nn.Conv(
  features=self.features, kernel_size=(1, 1),
- strides=(1, 1), kernel_init=self.kernel_init,
+ strides=(1, 1),
  use_bias=self.use_bias,
  dtype=self.dtype,
  precision=self.precision
@@ -153,7 +155,6 @@ class ConvLayer(nn.Module):
  features:int
  kernel_size:tuple=(3, 3)
  strides:tuple=(1, 1)
- kernel_init:Callable=kernel_init(1.0)
  dtype: Optional[Dtype] = None
  precision: PrecisionLike = None
 
@@ -164,7 +165,6 @@ class ConvLayer(nn.Module):
  features=self.features,
  kernel_size=self.kernel_size,
  strides=self.strides,
- kernel_init=self.kernel_init,
  dtype=self.dtype,
  precision=self.precision
  )
@@ -183,7 +183,6 @@ class ConvLayer(nn.Module):
  features=self.features,
  kernel_size=self.kernel_size,
  strides=self.strides,
- kernel_init=self.kernel_init,
  dtype=self.dtype,
  precision=self.precision
  )
@@ -192,7 +191,6 @@ class ConvLayer(nn.Module):
  features=self.features,
  kernel_size=self.kernel_size,
  strides=self.strides,
- kernel_init=self.kernel_init,
  dtype=self.dtype,
  precision=self.precision
  )
@@ -206,7 +204,6 @@ class Upsample(nn.Module):
  activation:Callable=jax.nn.swish
  dtype: Optional[Dtype] = None
  precision: PrecisionLike = None
- kernel_init:Callable=kernel_init(1.0)
 
  @nn.compact
  def __call__(self, x, residual=None):
@@ -221,7 +218,6 @@ class Upsample(nn.Module):
  strides=(1, 1),
  dtype=self.dtype,
  precision=self.precision,
- kernel_init=self.kernel_init
  )(out)
  if residual is not None:
  out = jnp.concatenate([out, residual], axis=-1)
@@ -233,7 +229,6 @@ class Downsample(nn.Module):
  activation:Callable=jax.nn.swish
  dtype: Optional[Dtype] = None
  precision: PrecisionLike = None
- kernel_init:Callable=kernel_init(1.0)
 
  @nn.compact
  def __call__(self, x, residual=None):
@@ -244,7 +239,6 @@ class Downsample(nn.Module):
  strides=(2, 2),
  dtype=self.dtype,
  precision=self.precision,
- kernel_init=self.kernel_init
  )(x)
  if residual is not None:
  if residual.shape[1] > out.shape[1]:
@@ -269,7 +263,6 @@ class ResidualBlock(nn.Module):
  direction:str=None
  res:int=2
  norm_groups:int=8
- kernel_init:Callable=kernel_init(1.0)
  dtype: Optional[Dtype] = None
  precision: PrecisionLike = None
  named_norms:bool=False
@@ -296,7 +289,6 @@ class ResidualBlock(nn.Module):
  features=self.features,
  kernel_size=self.kernel_size,
  strides=self.strides,
- kernel_init=self.kernel_init,
  name="conv1",
  dtype=self.dtype,
  precision=self.precision
@@ -321,7 +313,6 @@ class ResidualBlock(nn.Module):
  features=self.features,
  kernel_size=self.kernel_size,
  strides=self.strides,
- kernel_init=self.kernel_init,
  name="conv2",
  dtype=self.dtype,
  precision=self.precision
@@ -333,7 +324,6 @@ class ResidualBlock(nn.Module):
  features=self.features,
  kernel_size=(1, 1),
  strides=1,
- kernel_init=self.kernel_init,
  name="residual_conv",
  dtype=self.dtype,
  precision=self.precision
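Note: the dominant change in common.py is the removal of the kernel_init fields and arguments, so every nn.Conv / nn.DenseGeneral here now falls back to Flax's built-in default kernel initializer. A small standalone check of what that fallback means (assuming flax.linen, whose documented default kernel initializer is lecun_normal):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    x = jnp.ones((1, 32, 32, 3))
    implicit = nn.Conv(features=64, kernel_size=(3, 3))
    explicit = nn.Conv(features=64, kernel_size=(3, 3),
                       kernel_init=nn.initializers.lecun_normal())
    p1 = implicit.init(jax.random.PRNGKey(0), x)
    p2 = explicit.init(jax.random.PRNGKey(0), x)
    # Same key, same initializer -> identical kernels
    assert jnp.allclose(p1["params"]["kernel"], p2["params"]["kernel"])
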
flaxdiff/models/simple_unet.py CHANGED
@@ -10,17 +10,16 @@ from functools import partial
 
  class Unet(nn.Module):
  output_channels:int=3
- emb_features:int=64*4,
- feature_depths:list=[64, 128, 256, 512],
- attention_configs:list=[{"heads":8}, {"heads":8}, {"heads":8}, {"heads":8}],
- num_res_blocks:int=2,
- num_middle_res_blocks:int=1,
+ emb_features:int=64*4
+ feature_depths:list=(64, 128, 256, 512)
+ attention_configs:list=({"heads":8}, {"heads":8}, {"heads":8}, {"heads":8})
+ num_res_blocks:int=2
+ num_middle_res_blocks:int=1
  activation:Callable = jax.nn.swish
  norm_groups:int=8
  dtype: Optional[Dtype] = None
  precision: PrecisionLike = None
  named_norms: bool = False # This is for backward compatibility reasons; older checkpoints have named norms
- kernel_init: Callable = partial(kernel_init, dtype=jnp.float32)
 
  def setup(self):
  if self.norm_groups > 0:
@@ -50,7 +49,6 @@ class Unet(nn.Module):
  features=self.feature_depths[0],
  kernel_size=(3, 3),
  strides=(1, 1),
- kernel_init=self.kernel_init(scale=1.0),
  dtype=self.dtype,
  precision=self.precision
  )(x)
@@ -65,7 +63,6 @@ class Unet(nn.Module):
  down_conv_type,
  name=f"down_{i}_residual_{j}",
  features=dim_in,
- kernel_init=self.kernel_init(scale=1.0),
  kernel_size=(3, 3),
  strides=(1, 1),
  activation=self.activation,
@@ -85,7 +82,6 @@ class Unet(nn.Module):
  force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False),
  norm_inputs=attention_config.get("norm_inputs", True),
  explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
- kernel_init=self.kernel_init(scale=1.0),
  name=f"down_{i}_attention_{j}")(x, textcontext)
  # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
  downs.append(x)
@@ -108,7 +104,6 @@ class Unet(nn.Module):
  middle_conv_type,
  name=f"middle_res1_{j}",
  features=middle_dim_out,
- kernel_init=self.kernel_init(scale=1.0),
  kernel_size=(3, 3),
  strides=(1, 1),
  activation=self.activation,
@@ -129,13 +124,11 @@ class Unet(nn.Module):
  force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
  norm_inputs=middle_attention.get("norm_inputs", True),
  explicitly_add_residual=middle_attention.get("explicitly_add_residual", True),
- kernel_init=self.kernel_init(scale=1.0),
  name=f"middle_attention_{j}")(x, textcontext)
  x = ResidualBlock(
  middle_conv_type,
  name=f"middle_res2_{j}",
  features=middle_dim_out,
- kernel_init=self.kernel_init(scale=1.0),
  kernel_size=(3, 3),
  strides=(1, 1),
  activation=self.activation,
@@ -157,7 +150,6 @@ class Unet(nn.Module):
  up_conv_type,# if j == 0 else "separable",
  name=f"up_{i}_residual_{j}",
  features=dim_out,
- kernel_init=self.kernel_init(scale=1.0),
  kernel_size=kernel_size,
  strides=(1, 1),
  activation=self.activation,
@@ -177,7 +169,6 @@ class Unet(nn.Module):
  force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
  norm_inputs=attention_config.get("norm_inputs", True),
  explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
- kernel_init=self.kernel_init(scale=1.0),
  name=f"up_{i}_attention_{j}")(x, textcontext)
  # print("Upscaling ", i, x.shape)
  if i != len(feature_depths) - 1:
@@ -196,7 +187,6 @@ class Unet(nn.Module):
  features=self.feature_depths[0],
  kernel_size=(3, 3),
  strides=(1, 1),
- kernel_init=self.kernel_init(scale=1.0),
  dtype=self.dtype,
  precision=self.precision
  )(x)
@@ -207,7 +197,6 @@ class Unet(nn.Module):
  conv_type,
  name="final_residual",
  features=self.feature_depths[0],
- kernel_init=self.kernel_init(scale=1.0),
  kernel_size=(3,3),
  strides=(1, 1),
  activation=self.activation,
@@ -226,7 +215,7 @@ class Unet(nn.Module):
  kernel_size=(3, 3),
  strides=(1, 1),
  # activation=jax.nn.mish
- kernel_init=self.kernel_init(scale=0.0),
+ # kernel_init=self.kernel_init(scale=0.0),
  dtype=self.dtype,
  precision=self.precision
  )(x)
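Note: beyond dropping kernel_init, the Unet field list fixes a subtle Python bug: the old defaults ended in trailing commas, which turns each annotated default into a one-element tuple, and the mutable list defaults become tuples. A quick standalone illustration of the trailing-comma behaviour (not code from the package):

    class Before:
        emb_features: int = 64 * 4,     # trailing comma: the default is actually the tuple (256,)
        num_res_blocks: int = 2,        # likewise (2,)

    class After:
        emb_features: int = 64 * 4      # 256
        num_res_blocks: int = 2         # 2

    print(Before.emb_features, After.emb_features)   # (256,) 256
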
flaxdiff/models/simple_vit.py CHANGED
@@ -23,7 +23,6 @@ class PatchEmbedding(nn.Module):
  embedding_dim: int
  dtype: Any = jnp.float32
  precision: Any = jax.lax.Precision.HIGH
- kernel_init: Callable = partial(kernel_init, 1.0)
 
  @nn.compact
  def __call__(self, x):
@@ -34,7 +33,6 @@ class PatchEmbedding(nn.Module):
  kernel_size=(self.patch_size, self.patch_size),
  strides=(self.patch_size, self.patch_size),
  dtype=self.dtype,
- kernel_init=self.kernel_init(),
  precision=self.precision)(x)
  x = jnp.reshape(x, (batch, -1, self.embedding_dim))
  return x
@@ -53,7 +51,7 @@ class PositionalEncoding(nn.Module):
  class UViT(nn.Module):
  output_channels:int=3
  patch_size: int = 16
- emb_features:int=768,
+ emb_features:int=768
  num_layers: int = 12
  num_heads: int = 12
  dropout_rate: float = 0.1
@@ -67,7 +65,7 @@ class UViT(nn.Module):
  norm_groups:int=8
  dtype: Optional[Dtype] = None
  precision: PrecisionLike = None
- kernel_init: Callable = partial(kernel_init, scale=1.0)
+ # kernel_init: Callable = partial(kernel_init, scale=1.0)
  add_residualblock_output: bool = False
  norm_inputs: bool = False
  explicitly_add_residual: bool = True
@@ -88,10 +86,10 @@ class UViT(nn.Module):
 
  # Patch embedding
  x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features,
- dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init)(x)
+ dtype=self.dtype, precision=self.precision)(x)
  num_patches = x.shape[1]
 
- context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
+ context_emb = nn.DenseGeneral(features=self.emb_features,
  dtype=self.dtype, precision=self.precision)(textcontext)
  num_text_tokens = textcontext.shape[1]
 
@@ -116,7 +114,7 @@ class UViT(nn.Module):
  only_pure_attention=False,
  norm_inputs=self.norm_inputs,
  explicitly_add_residual=self.explicitly_add_residual,
- kernel_init=self.kernel_init())(x)
+ )(x)
  skips.append(x)
 
  # Middle block
@@ -126,12 +124,12 @@ class UViT(nn.Module):
  only_pure_attention=False,
  norm_inputs=self.norm_inputs,
  explicitly_add_residual=self.explicitly_add_residual,
- kernel_init=self.kernel_init())(x)
+ )(x)
 
  # # Out blocks
  for i in range(self.num_layers // 2):
  x = jnp.concatenate([x, skips.pop()], axis=-1)
- x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
+ x = nn.DenseGeneral(features=self.emb_features,
  dtype=self.dtype, precision=self.precision)(x)
  x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
@@ -139,13 +137,13 @@ class UViT(nn.Module):
  only_pure_attention=False,
  norm_inputs=self.norm_inputs,
  explicitly_add_residual=self.explicitly_add_residual,
- kernel_init=self.kernel_init())(x)
+ )(x)
 
  # print(f'Shape of x after transformer blocks: {x.shape}')
  x = self.norm()(x)
 
  patch_dim = self.patch_size ** 2 * self.output_channels
- x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init())(x)
+ x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision)(x)
  x = x[:, 1 + num_text_tokens:, :]
  x = unpatchify(x, channels=self.output_channels)
 
@@ -159,7 +157,6 @@ class UViT(nn.Module):
  kernel_size=(3, 3),
  strides=(1, 1),
  # activation=jax.nn.mish
- kernel_init=self.kernel_init(scale=0.0),
  dtype=self.dtype,
  precision=self.precision
  )(x)
@@ -173,7 +170,6 @@ class UViT(nn.Module):
  kernel_size=(3, 3),
  strides=(1, 1),
  # activation=jax.nn.mish
- kernel_init=self.kernel_init(scale=0.0),
  dtype=self.dtype,
  precision=self.precision
  )(x)