flaxdiff 0.1.38__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. flaxdiff/data/__init__.py +5 -1
  2. flaxdiff/data/benchmark_decord.py +443 -0
  3. flaxdiff/data/dataloaders.py +608 -0
  4. flaxdiff/data/dataset_map.py +61 -6
  5. flaxdiff/data/online_loader.py +779 -150
  6. flaxdiff/data/sources/audio_utils.py +142 -0
  7. flaxdiff/data/sources/av_example.py +125 -0
  8. flaxdiff/data/sources/av_utils.py +590 -0
  9. flaxdiff/data/sources/base.py +129 -0
  10. flaxdiff/data/sources/images.py +309 -0
  11. flaxdiff/data/sources/utils.py +158 -0
  12. flaxdiff/data/sources/videos.py +250 -0
  13. flaxdiff/data/sources/voxceleb2.py +412 -0
  14. flaxdiff/inference/__init__.py +0 -0
  15. flaxdiff/inference/pipeline.py +260 -0
  16. flaxdiff/inference/utils.py +320 -0
  17. flaxdiff/inputs/__init__.py +173 -0
  18. flaxdiff/inputs/encoders.py +98 -0
  19. flaxdiff/models/__init__.py +2 -1
  20. flaxdiff/models/attention.py +22 -16
  21. flaxdiff/models/autoencoder/autoencoder.py +141 -9
  22. flaxdiff/models/autoencoder/diffusers.py +88 -25
  23. flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
  24. flaxdiff/models/common.py +8 -18
  25. flaxdiff/models/simple_unet.py +6 -17
  26. flaxdiff/models/simple_vit.py +9 -13
  27. flaxdiff/models/unet_3d.py +446 -0
  28. flaxdiff/models/unet_3d_blocks.py +505 -0
  29. flaxdiff/samplers/common.py +358 -96
  30. flaxdiff/samplers/ddim.py +44 -5
  31. flaxdiff/schedulers/karras.py +20 -12
  32. flaxdiff/trainer/__init__.py +2 -1
  33. flaxdiff/trainer/autoencoder_trainer.py +1 -2
  34. flaxdiff/trainer/diffusion_trainer.py +35 -29
  35. flaxdiff/trainer/general_diffusion_trainer.py +583 -0
  36. flaxdiff/trainer/simple_trainer.py +51 -16
  37. flaxdiff/utils.py +128 -57
  38. {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
  39. flaxdiff-0.2.0.dist-info/RECORD +64 -0
  40. {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
  41. flaxdiff/data/datasets.py +0 -169
  42. flaxdiff/data/sources/gcs.py +0 -81
  43. flaxdiff/data/sources/tfds.py +0 -79
  44. flaxdiff/trainer/video_diffusion_trainer.py +0 -62
  45. flaxdiff-0.1.38.dist-info/RECORD +0 -50
  46. {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
flaxdiff/inputs/__init__.py
@@ -0,0 +1,173 @@
+ import jax
+ import jax.numpy as jnp
+ import flax.struct as struct
+ import flax.linen as nn
+ from typing import Any, Callable, Dict, List, Optional, Tuple
+ from dataclasses import dataclass
+ from functools import partial
+ import numpy as np
+ from jax.sharding import Mesh, PartitionSpec as P
+ from abc import ABC, abstractmethod
+
+ from flaxdiff.models.autoencoder import AutoEncoder
+ from .encoders import *
+
+ @dataclass
+ class ConditionalInputConfig:
+     """Class representing a conditional input for the model."""
+     encoder: ConditioningEncoder
+     conditioning_data_key: str = None  # Key in the batch for this conditioning input
+     pretokenized: bool = False
+     unconditional_input: Any = None
+     model_key_override: Optional[str] = None  # Optional key override for the model
+
+     __uncond_cache__ = None  # Cache for unconditional input
+
+     def __post_init__(self):
+         if self.unconditional_input is not None:
+             uncond = self.encoder([self.unconditional_input])
+         else:
+             uncond = self.encoder([""])  # Default empty text
+         self.__uncond_cache__ = uncond  # Cache the unconditional input
+
+     def __call__(self, batch_data):
+         """Process batch data to produce conditioning."""
+         key = self.conditioning_data_key if self.conditioning_data_key else self.encoder.key
+         if self.pretokenized:
+             return self.encoder.encode_from_tokens(batch_data[key])
+         return self.encoder(batch_data[key])
+
+     def get_unconditional(self):
+         """Get the unconditional version of this input."""
+         return self.__uncond_cache__
+
+     def serialize(self):
+         """Serialize the configuration."""
+         serialized_config = {
+             "encoder": self.encoder.serialize(),
+             "encoder_key": self.encoder.key,
+             "conditioning_data_key": self.conditioning_data_key,
+             "unconditional_input": self.unconditional_input,
+             "model_key_override": self.model_key_override,
+         }
+         return serialized_config
+
+     @staticmethod
+     def deserialize(serialized_config):
+         """Deserialize the configuration."""
+         encoder_key = serialized_config["encoder_key"]
+         encoder_class = CONDITIONAL_ENCODERS_REGISTRY.get(encoder_key)
+         if encoder_class is None:
+             raise ValueError(f"Unknown encoder type: {encoder_key}")
+
+         # Create the encoder instance
+         encoder = encoder_class.deserialize(serialized_config["encoder"])
+         # Deserialize the rest of the configuration
+         conditioning_data_key = serialized_config.get("conditioning_data_key")
+         unconditional_input = serialized_config.get("unconditional_input")
+         model_key_override = serialized_config.get("model_key_override")
+         return ConditionalInputConfig(
+             encoder=encoder,
+             conditioning_data_key=conditioning_data_key,
+             unconditional_input=unconditional_input,
+             model_key_override=model_key_override,
+         )
+
+ @dataclass
+ class DiffusionInputConfig:
+     """Configuration for the input data."""
+     sample_data_key: str  # Key in the batch for the sample data
+     sample_data_shape: Tuple[int, ...]
+     conditions: List[ConditionalInputConfig]
+
+     def get_input_shapes(
+         self,
+         autoencoder: AutoEncoder = None,
+         sample_model_key: str = 'x',
+         time_embeddings_model_key: str = 'temb',
+     ) -> Dict[str, Tuple[int, ...]]:
+         """Get the shapes of the input data."""
+         if len(self.sample_data_shape) == 3:
+             H, W, C = self.sample_data_shape
+         elif len(self.sample_data_shape) == 4:
+             T, H, W, C = self.sample_data_shape
+         else:
+             raise ValueError(f"Unsupported shape for sample data {self.sample_data_shape}")
+         if autoencoder is not None:
+             downscale_factor = autoencoder.downscale_factor
+             H = H // downscale_factor
+             W = W // downscale_factor
+             C = autoencoder.latent_channels
+
+         input_shapes = {
+             sample_model_key: (H, W, C),
+             time_embeddings_model_key: (),
+         }
+         for cond in self.conditions:
+             # Get the shape of the conditioning data by calling the get_unconditional method
+             unconditional = cond.get_unconditional()
+             key = cond.model_key_override if cond.model_key_override else cond.encoder.key
+             input_shapes[key] = unconditional[0].shape
+
+         print(f"Calculated input shapes: {input_shapes}")
+         return input_shapes
+
+     def get_unconditionals(self):
+         """Get unconditional inputs for all conditions."""
+         unconditionals = []
+         for cond in self.conditions:
+             uncond = cond.get_unconditional()
+             unconditionals.append(uncond)
+         return unconditionals
+
+     def process_conditioning(self, batch_data, uncond_mask: Optional[jnp.ndarray] = None):
+         """Process the conditioning data."""
+         results = []
+
+         for cond in self.conditions:
+             cond_embeddings = cond(batch_data)
+             if uncond_mask is not None:
+                 assert len(uncond_mask) == len(cond_embeddings), "Unconditional mask length must match the batch size."
+                 uncond_embedding = cond.get_unconditional()
+
+                 # Reshape uncond_mask to be broadcastable with the conditioning embeddings:
+                 # if cond_embeddings has shape (B, T, D), reshape uncond_mask to (B, 1, 1)
+                 broadcast_shape = [len(uncond_mask)] + [1] * (cond_embeddings.ndim - 1)
+                 reshaped_mask = jnp.reshape(uncond_mask, broadcast_shape)
+
+                 # Repeat uncond_embedding to match the batch size
+                 batch_size = len(cond_embeddings)
+                 repeated_uncond = jnp.repeat(uncond_embedding, batch_size, axis=0)
+
+                 # Apply the unconditional embedding based on the mask
+                 cond_embeddings = jnp.where(reshaped_mask, repeated_uncond, cond_embeddings)
+
+             results.append(cond_embeddings)
+         return results
+
+     def serialize(self):
+         """Serialize the configuration."""
+         serialized_config = {
+             "sample_data_key": self.sample_data_key,
+             "sample_data_shape": self.sample_data_shape,
+             "conditions": [cond.serialize() for cond in self.conditions],
+         }
+         return serialized_config
+
+     @staticmethod
+     def deserialize(serialized_config):
+         """Deserialize the configuration."""
+         sample_data_key = serialized_config["sample_data_key"]
+         sample_data_shape = tuple(serialized_config["sample_data_shape"])
+         conditions = serialized_config["conditions"]
+
+         # Deserialize each condition
+         deserialized_conditions = []
+         for cond in conditions:
+             deserialized_conditions.append(ConditionalInputConfig.deserialize(cond))
+
+         return DiffusionInputConfig(
+             sample_data_key=sample_data_key,
+             sample_data_shape=sample_data_shape,
+             conditions=deserialized_conditions,
+         )
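For orientation, here is a minimal sketch of how these new classes might be wired together; the batch keys, shapes, and variable names below are illustrative assumptions, not part of the package diff:

# Hypothetical usage sketch (keys and shapes are assumptions, not from the diff).
from flaxdiff.inputs import ConditionalInputConfig, DiffusionInputConfig
from flaxdiff.inputs.encoders import CLIPTextEncoder

text_cond = ConditionalInputConfig(
    encoder=CLIPTextEncoder.from_modelname(backend="jax"),
    conditioning_data_key="caption",   # assumed batch key
    unconditional_input="",            # empty prompt as the CFG null condition
)
input_config = DiffusionInputConfig(
    sample_data_key="image",           # assumed batch key
    sample_data_shape=(256, 256, 3),   # (H, W, C)
    conditions=[text_cond],
)

# Shapes the denoiser should accept: 'x', 'temb', plus one entry per condition.
input_shapes = input_config.get_input_shapes()
# During training, process_conditioning swaps masked rows for the cached
# unconditional embedding, i.e. classifier-free guidance dropout:
# embeddings = input_config.process_conditioning(batch, uncond_mask=drop_mask)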
flaxdiff/inputs/encoders.py
@@ -0,0 +1,98 @@
+ import jax.numpy as jnp
+ import flax.linen as nn
+ from typing import Callable
+ from dataclasses import dataclass
+ from abc import ABC, abstractmethod
+
+ @dataclass
+ class ConditioningEncoder(ABC):
+     model: nn.Module
+     tokenizer: Callable
+
+     @property
+     def key(self):
+         name = self.tokenizer.__name__
+         # Remove the 'Encoder' suffix from the name and lowercase it
+         if name.endswith("Encoder"):
+             name = name[:-7].lower()
+         return name
+
+     def __call__(self, data):
+         tokens = self.tokenize(data)
+         outputs = self.encode_from_tokens(tokens)
+         return outputs
+
+     def encode_from_tokens(self, tokens):
+         outputs = self.model(input_ids=tokens['input_ids'],
+                              attention_mask=tokens['attention_mask'])
+         last_hidden_state = outputs.last_hidden_state
+         return last_hidden_state
+
+     def tokenize(self, data):
+         tokens = self.tokenizer(data, padding="max_length",
+                                 max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np")
+         return tokens
+
+     @abstractmethod
+     def serialize(self):
+         """Serialize the encoder configuration."""
+         pass
+
+     @staticmethod
+     @abstractmethod
+     def deserialize(serialized_config):
+         """Deserialize the encoder configuration."""
+         pass
+
+ @dataclass
+ class TextEncoder(ConditioningEncoder):
+     """Text Encoder."""
+     @property
+     def key(self):
+         return "text"
+
+ @dataclass
+ class CLIPTextEncoder(TextEncoder):
+     """CLIP Text Encoder."""
+     modelname: str
+     backend: str
+
+     @staticmethod
+     def from_modelname(modelname: str = "openai/clip-vit-large-patch14", backend: str = "jax"):
+         from transformers import (
+             CLIPTextModel,
+             FlaxCLIPTextModel,
+             AutoTokenizer,
+         )
+         modelname = "openai/clip-vit-large-patch14"
+         if backend == "jax":
+             model = FlaxCLIPTextModel.from_pretrained(
+                 modelname, dtype=jnp.bfloat16)
+         else:
+             model = CLIPTextModel.from_pretrained(modelname)
+         tokenizer = AutoTokenizer.from_pretrained(modelname, dtype=jnp.float16)
+         return CLIPTextEncoder(
+             model=model,
+             tokenizer=tokenizer,
+             modelname=modelname,
+             backend=backend
+         )
+
+     def serialize(self):
+         """Serialize the encoder configuration."""
+         serialized_config = {
+             "modelname": self.modelname,
+             "backend": self.backend,
+         }
+         return serialized_config
+
+     @staticmethod
+     def deserialize(serialized_config):
+         """Deserialize the encoder configuration."""
+         modelname = serialized_config["modelname"]
+         backend = serialized_config["backend"]
+         return CLIPTextEncoder.from_modelname(modelname=modelname, backend=backend)
+
+ CONDITIONAL_ENCODERS_REGISTRY = {
+     "text": CLIPTextEncoder,
+ }
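The encoder's __call__ simply tokenizes and then encodes. A small sketch of that call path follows (the prompt and shape comment are illustrative; note also that from_modelname reassigns modelname to the CLIP ViT-L/14 default internally, so custom model names are currently ignored):

# Sketch only; the prompt is illustrative.
encoder = CLIPTextEncoder.from_modelname(backend="jax")
tokens = encoder.tokenize(["a photo of an astronaut riding a horse"])
embeddings = encoder.encode_from_tokens(tokens)
# Equivalent one-step call; for CLIP ViT-L/14 this is the text model's
# last hidden state, of shape (batch, 77, 768).
embeddings = encoder(["a photo of an astronaut riding a horse"])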
flaxdiff/models/__init__.py
@@ -1 +1,2 @@
- from .simple_unet import *
+ from .simple_unet import *
+ # from .video_unet import FlaxUNet3DConditionModel, BCHWModelWrapper, FlaxTemporalConvLayer
flaxdiff/models/attention.py
@@ -23,7 +23,7 @@ class EfficientAttention(nn.Module):
      dtype: Optional[Dtype] = None
      precision: PrecisionLike = None
      use_bias: bool = True
-     kernel_init: Callable = kernel_init(1.0)
+     # kernel_init: Callable = kernel_init(1.0)
      force_fp32_for_softmax: bool = True

      def setup(self):
@@ -34,15 +34,21 @@ class EfficientAttention(nn.Module):
              self.heads * self.dim_head,
              precision=self.precision,
              use_bias=self.use_bias,
-             kernel_init=self.kernel_init,
+             # kernel_init=self.kernel_init,
              dtype=self.dtype
          )
          self.query = dense(name="to_q")
          self.key = dense(name="to_k")
          self.value = dense(name="to_v")

-         self.proj_attn = nn.DenseGeneral(self.query_dim, use_bias=False, precision=self.precision,
-                                          kernel_init=self.kernel_init, dtype=self.dtype, name="to_out_0")
+         self.proj_attn = nn.DenseGeneral(
+             self.query_dim,
+             use_bias=False,
+             precision=self.precision,
+             # kernel_init=self.kernel_init,
+             dtype=self.dtype,
+             name="to_out_0"
+         )
          # self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16)

      def _reshape_tensor_to_head_dim(self, tensor):
@@ -115,7 +121,7 @@ class NormalAttention(nn.Module):
      dtype: Optional[Dtype] = None
      precision: PrecisionLike = None
      use_bias: bool = True
-     kernel_init: Callable = kernel_init(1.0)
+     # kernel_init: Callable = kernel_init(1.0)
      force_fp32_for_softmax: bool = True

      def setup(self):
@@ -126,7 +132,7 @@
              axis=-1,
              precision=self.precision,
              use_bias=self.use_bias,
-             kernel_init=self.kernel_init,
+             # kernel_init=self.kernel_init,
              dtype=self.dtype
          )
          self.query = dense(name="to_q")
@@ -140,7 +146,7 @@
              use_bias=self.use_bias,
              dtype=self.dtype,
              name="to_out_0",
-             kernel_init=self.kernel_init
+             # kernel_init=self.kernel_init
              # kernel_init=jax.nn.initializers.xavier_uniform()
          )

@@ -236,7 +242,7 @@ class BasicTransformerBlock(nn.Module):
      dtype: Optional[Dtype] = None
      precision: PrecisionLike = None
      use_bias: bool = True
-     kernel_init: Callable = kernel_init(1.0)
+     # kernel_init: Callable = kernel_init(1.0)
      use_flash_attention:bool = False
      use_cross_only:bool = False
      only_pure_attention:bool = False
@@ -256,7 +262,7 @@
              precision=self.precision,
              use_bias=self.use_bias,
              dtype=self.dtype,
-             kernel_init=self.kernel_init,
+             # kernel_init=self.kernel_init,
              force_fp32_for_softmax=self.force_fp32_for_softmax
          )
          self.attention2 = attenBlock(
@@ -267,7 +273,7 @@
              precision=self.precision,
              use_bias=self.use_bias,
              dtype=self.dtype,
-             kernel_init=self.kernel_init,
+             # kernel_init=self.kernel_init,
              force_fp32_for_softmax=self.force_fp32_for_softmax
          )

@@ -303,7 +309,7 @@ class TransformerBlock(nn.Module):
      use_self_and_cross:bool = True
      only_pure_attention:bool = False
      force_fp32_for_softmax: bool = True
-     kernel_init: Callable = kernel_init(1.0)
+     # kernel_init: Callable = kernel_init(1.0)
      norm_inputs: bool = True
      explicitly_add_residual: bool = True

@@ -317,12 +323,12 @@
          if self.use_linear_attention:
              projected_x = nn.Dense(features=inner_dim,
                                     use_bias=False, precision=self.precision,
-                                    kernel_init=self.kernel_init,
+                                    # kernel_init=self.kernel_init,
                                     dtype=self.dtype, name=f'project_in')(x)
          else:
              projected_x = nn.Conv(
                  features=inner_dim, kernel_size=(1, 1),
-                 kernel_init=self.kernel_init,
+                 # kernel_init=self.kernel_init,
                  strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                  precision=self.precision, name=f'project_in_conv',
              )(x)
@@ -344,19 +350,19 @@
              use_cross_only=(not self.use_self_and_cross),
              only_pure_attention=self.only_pure_attention,
              force_fp32_for_softmax=self.force_fp32_for_softmax,
-             kernel_init=self.kernel_init
+             # kernel_init=self.kernel_init
          )(projected_x, context)

          if self.use_projection == True:
              if self.use_linear_attention:
                  projected_x = nn.Dense(features=C, precision=self.precision,
                                         dtype=self.dtype, use_bias=False,
-                                        kernel_init=self.kernel_init,
+                                        # kernel_init=self.kernel_init,
                                         name=f'project_out')(projected_x)
              else:
                  projected_x = nn.Conv(
                      features=C, kernel_size=(1, 1),
-                     kernel_init=self.kernel_init,
+                     # kernel_init=self.kernel_init,
                      strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                      precision=self.precision, name=f'project_out_conv',
                  )(projected_x)
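All of these attention.py hunks make the same change: the explicit kernel_init argument is commented out, so the affected nn.Dense, nn.DenseGeneral, and nn.Conv layers fall back to Flax's default kernel initializer (lecun_normal). A minimal sketch of the behavioral consequence, with hypothetical dimensions:

import jax
import jax.numpy as jnp
import flax.linen as nn

# With kernel_init omitted, Flax initializes the kernel with lecun_normal()
# by default; the 0.2.0 layers therefore no longer use kernel_init(1.0).
layer = nn.Dense(features=64)
params = layer.init(jax.random.PRNGKey(0), jnp.ones((1, 32)))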
flaxdiff/models/autoencoder/autoencoder.py
@@ -1,19 +1,151 @@
  import jax
  import jax.numpy as jnp
  from flax import linen as nn
- from typing import Dict, Callable, Sequence, Any, Union
+ from typing import Dict, Callable, Sequence, Any, Union, Optional
  import einops
  from ..common import kernel_init, ConvLayer, Upsample, Downsample, PixelShuffle
+ from dataclasses import dataclass
+ from abc import ABC, abstractmethod

-
- class AutoEncoder():
-     def encode(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
+ @dataclass
+ class AutoEncoder(ABC):
+     """Base class for autoencoder models with video support.
+
+     This class defines the interface for autoencoders and provides
+     video handling functionality, allowing child classes to focus
+     on implementing the core encoding/decoding for individual frames.
+     """
+     @abstractmethod
+     def __encode__(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
+         """Abstract method for encoding a batch of images.
+
+         Child classes must implement this method to perform the actual encoding.
+
+         Args:
+             x: Input tensor of shape [B, H, W, C] (batch of images)
+             **kwargs: Additional arguments for the encoding process
+
+         Returns:
+             Encoded latent representation
+         """
+         raise NotImplementedError
+
+     @abstractmethod
+     def __decode__(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
+         """Abstract method for decoding a batch of latents.
+
+         Child classes must implement this method to perform the actual decoding.
+
+         Args:
+             z: Latent tensor of shape [B, h, w, c] (encoded representation)
+             **kwargs: Additional arguments for the decoding process
+
+         Returns:
+             Decoded images
+         """
          raise NotImplementedError

-     def decode(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
+     def encode(self, x: jnp.ndarray, key: Optional[jax.random.PRNGKey] = None, **kwargs) -> jnp.ndarray:
+         """Encode input data, with special handling for video data.
+
+         This method handles both standard image batches and video data (5D tensors).
+         For videos, it reshapes the input, processes each frame, and then restores
+         the temporal dimension.
+
+         Args:
+             x: Input tensor, either [B, H, W, C] for images or [B, T, H, W, C] for videos
+             key: Optional random key for stochastic encoding
+             **kwargs: Additional arguments passed to __encode__
+
+         Returns:
+             Encoded representation with the same batch and temporal dimensions as input
+         """
+         # Check for video data (5D tensor)
+         is_video = len(x.shape) == 5
+
+         if is_video:
+             # Extract dimensions for reshaping
+             batch_size, seq_len, height, width, channels = x.shape
+
+             # Reshape to [B*T, H, W, C] to process as regular images
+             x_reshaped = x.reshape(-1, height, width, channels)
+
+             # Encode all frames
+             latent = self.__encode__(x_reshaped, key=key, **kwargs)
+
+             # Reshape back to include temporal dimension [B, T, h, w, c]
+             latent_shape = latent.shape
+             return latent.reshape(batch_size, seq_len, *latent_shape[1:])
+         else:
+             # Standard image processing
+             return self.__encode__(x, key=key, **kwargs)
+
+     def decode(self, z: jnp.ndarray, key: Optional[jax.random.PRNGKey] = None, **kwargs) -> jnp.ndarray:
+         """Decode latent representations, with special handling for video data.
+
+         This method handles both standard image latents and video latents (5D tensors).
+         For videos, it reshapes the input, processes each frame, and then restores
+         the temporal dimension.
+
+         Args:
+             z: Latent tensor, either [B, h, w, c] for images or [B, T, h, w, c] for videos
+             key: Optional random key for stochastic decoding
+             **kwargs: Additional arguments passed to __decode__
+
+         Returns:
+             Decoded output with the same batch and temporal dimensions as input
+         """
+         # Check for video data (5D tensor)
+         is_video = len(z.shape) == 5
+
+         if is_video:
+             # Extract dimensions for reshaping
+             batch_size, seq_len, height, width, channels = z.shape
+
+             # Reshape to [B*T, h, w, c] to process as regular latents
+             z_reshaped = z.reshape(-1, height, width, channels)
+
+             # Decode all frames
+             decoded = self.__decode__(z_reshaped, key=key, **kwargs)
+
+             # Reshape back to include temporal dimension [B, T, H, W, C]
+             decoded_shape = decoded.shape
+             return decoded.reshape(batch_size, seq_len, *decoded_shape[1:])
+         else:
+             # Standard latent processing
+             return self.__decode__(z, key=key, **kwargs)
+
+     def __call__(self, x: jnp.ndarray, key: Optional[jax.random.PRNGKey] = None, **kwargs):
+         """Encode and then decode the input (autoencoder).
+
+         Args:
+             x: Input tensor, either [B, H, W, C] for images or [B, T, H, W, C] for videos
+             key: Optional random key for stochastic encoding/decoding
+             **kwargs: Additional arguments for encoding and decoding
+
+         Returns:
+             Reconstructed output with the same dimensions as input
+         """
+         if key is not None:
+             encode_key, decode_key = jax.random.split(key)
+         else:
+             encode_key = decode_key = None
+
+         # Encode then decode
+         z = self.encode(x, key=encode_key, **kwargs)
+         return self.decode(z, key=decode_key, **kwargs)
+
+     @property
+     def spatial_scale(self) -> int:
+         """Get the spatial scale factor between input and latent spaces."""
+         return getattr(self, "_spatial_scale", None)
+
+     @property
+     def name(self) -> str:
+         """Get the name of the autoencoder model."""
          raise NotImplementedError

-     def __call__(self, x: jnp.ndarray):
-         latents = self.encode(x)
-         reconstructions = self.decode(latents)
-         return reconstructions
+     @abstractmethod
+     def serialize(self) -> Dict[str, Any]:
+         """Serialize the model parameters and configuration."""
+         raise NotImplementedError
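To illustrate the reworked contract, a hypothetical subclass implements only the per-frame __encode__/__decode__ (plus serialize), and the base class transparently folds a [B, T, H, W, C] video into [B*T, H, W, C] and restores the temporal axis. The class below is an illustrative no-op, not part of the package:

import jax.numpy as jnp
from typing import Any, Dict
from flaxdiff.models.autoencoder import AutoEncoder

class IdentityAutoEncoder(AutoEncoder):
    """Hypothetical no-op autoencoder, used only to exercise the base class."""

    def __encode__(self, x: jnp.ndarray, key=None, **kwargs) -> jnp.ndarray:
        return x  # a real model would map [B, H, W, C] -> [B, h, w, c]

    def __decode__(self, z: jnp.ndarray, key=None, **kwargs) -> jnp.ndarray:
        return z

    def serialize(self) -> Dict[str, Any]:
        return {}

video = jnp.zeros((2, 8, 64, 64, 3))           # [B, T, H, W, C]
latents = IdentityAutoEncoder().encode(video)  # frames folded to [B*T, ...] internally
assert latents.shape == video.shape            # temporal axis restored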