flaxdiff 0.1.38__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff shows the content of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- flaxdiff/data/__init__.py +5 -1
- flaxdiff/data/benchmark_decord.py +443 -0
- flaxdiff/data/dataloaders.py +608 -0
- flaxdiff/data/dataset_map.py +61 -6
- flaxdiff/data/online_loader.py +779 -150
- flaxdiff/data/sources/audio_utils.py +142 -0
- flaxdiff/data/sources/av_example.py +125 -0
- flaxdiff/data/sources/av_utils.py +590 -0
- flaxdiff/data/sources/base.py +129 -0
- flaxdiff/data/sources/images.py +309 -0
- flaxdiff/data/sources/utils.py +158 -0
- flaxdiff/data/sources/videos.py +250 -0
- flaxdiff/data/sources/voxceleb2.py +412 -0
- flaxdiff/inference/__init__.py +0 -0
- flaxdiff/inference/pipeline.py +260 -0
- flaxdiff/inference/utils.py +320 -0
- flaxdiff/inputs/__init__.py +173 -0
- flaxdiff/inputs/encoders.py +98 -0
- flaxdiff/models/__init__.py +2 -1
- flaxdiff/models/attention.py +22 -16
- flaxdiff/models/autoencoder/autoencoder.py +141 -9
- flaxdiff/models/autoencoder/diffusers.py +88 -25
- flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
- flaxdiff/models/common.py +8 -18
- flaxdiff/models/simple_unet.py +6 -17
- flaxdiff/models/simple_vit.py +9 -13
- flaxdiff/models/unet_3d.py +446 -0
- flaxdiff/models/unet_3d_blocks.py +505 -0
- flaxdiff/samplers/common.py +358 -96
- flaxdiff/samplers/ddim.py +44 -5
- flaxdiff/schedulers/karras.py +20 -12
- flaxdiff/trainer/__init__.py +2 -1
- flaxdiff/trainer/autoencoder_trainer.py +1 -2
- flaxdiff/trainer/diffusion_trainer.py +35 -29
- flaxdiff/trainer/general_diffusion_trainer.py +583 -0
- flaxdiff/trainer/simple_trainer.py +51 -16
- flaxdiff/utils.py +128 -57
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
- flaxdiff-0.2.0.dist-info/RECORD +64 -0
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
- flaxdiff/data/datasets.py +0 -169
- flaxdiff/data/sources/gcs.py +0 -81
- flaxdiff/data/sources/tfds.py +0 -79
- flaxdiff/trainer/video_diffusion_trainer.py +0 -62
- flaxdiff-0.1.38.dist-info/RECORD +0 -50
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
flaxdiff/inputs/__init__.py
ADDED
```diff
@@ -0,0 +1,173 @@
+import jax
+import jax.numpy as jnp
+import flax.struct as struct
+import flax.linen as nn
+from typing import Any, Callable, Dict, List, Optional, Tuple
+from dataclasses import dataclass
+from functools import partial
+import numpy as np
+from jax.sharding import Mesh, PartitionSpec as P
+from abc import ABC, abstractmethod
+
+from flaxdiff.models.autoencoder import AutoEncoder
+from .encoders import *
+
+@dataclass
+class ConditionalInputConfig:
+    """Class representing a conditional input for the model."""
+    encoder: ConditioningEncoder
+    conditioning_data_key: str = None  # Key in the batch for this conditioning input
+    pretokenized: bool = False
+    unconditional_input: Any = None
+    model_key_override: Optional[str] = None  # Optional key override for the model
+
+    __uncond_cache__ = None  # Cache for unconditional input
+
+    def __post_init__(self):
+        if self.unconditional_input is not None:
+            uncond = self.encoder([self.unconditional_input])
+        else:
+            uncond = self.encoder([""])  # Default empty text
+        self.__uncond_cache__ = uncond  # Cache the unconditional input
+
+    def __call__(self, batch_data):
+        """Process batch data to produce conditioning."""
+        key = self.conditioning_data_key if self.conditioning_data_key else self.encoder.key
+        if self.pretokenized:
+            return self.encoder.encode_from_tokens(batch_data[key])
+        return self.encoder(batch_data[key])
+
+    def get_unconditional(self):
+        """Get unconditional version of this input."""
+        return self.__uncond_cache__
+
+    def serialize(self):
+        """Serialize the configuration."""
+        serialized_config = {
+            "encoder": self.encoder.serialize(),
+            "encoder_key": self.encoder.key,
+            "conditioning_data_key": self.conditioning_data_key,
+            "unconditional_input": self.unconditional_input,
+            "model_key_override": self.model_key_override,
+        }
+        return serialized_config
+
+    @staticmethod
+    def deserialize(serialized_config):
+        """Deserialize the configuration."""
+        encoder_key = serialized_config["encoder_key"]
+        encoder_class = CONDITIONAL_ENCODERS_REGISTRY.get(encoder_key)
+        if encoder_class is None:
+            raise ValueError(f"Unknown encoder type: {encoder_key}")
+
+        # Create the encoder instance
+        encoder = encoder_class.deserialize(serialized_config["encoder"])
+        # Deserialize the rest of the configuration
+        conditioning_data_key = serialized_config.get("conditioning_data_key")
+        unconditional_input = serialized_config.get("unconditional_input")
+        model_key_override = serialized_config.get("model_key_override")
+        return ConditionalInputConfig(
+            encoder=encoder,
+            conditioning_data_key=conditioning_data_key,
+            unconditional_input=unconditional_input,
+            model_key_override=model_key_override,
+        )
+
+@dataclass
+class DiffusionInputConfig:
+    """Configuration for the input data."""
+    sample_data_key: str  # Key in the batch for the sample data
+    sample_data_shape: Tuple[int, ...]
+    conditions: List[ConditionalInputConfig]
+
+    def get_input_shapes(
+        self,
+        autoencoder: AutoEncoder = None,
+        sample_model_key: str = 'x',
+        time_embeddings_model_key: str = 'temb',
+    ) -> Dict[str, Tuple[int, ...]]:
+        """Get the shapes of the input data."""
+        if len(self.sample_data_shape) == 3:
+            H, W, C = self.sample_data_shape
+        elif len(self.sample_data_shape) == 4:
+            T, H, W, C = self.sample_data_shape
+        else:
+            raise ValueError(f"Unsupported shape for sample data {self.sample_data_shape}")
+        if autoencoder is not None:
+            downscale_factor = autoencoder.downscale_factor
+            H = H // downscale_factor
+            W = W // downscale_factor
+            C = autoencoder.latent_channels
+
+        input_shapes = {
+            sample_model_key: (H, W, C),
+            time_embeddings_model_key: (),
+        }
+        for cond in self.conditions:
+            # Get the shape of the conditioning data by calling the get_unconditional method
+            unconditional = cond.get_unconditional()
+            key = cond.model_key_override if cond.model_key_override else cond.encoder.key
+            input_shapes[key] = unconditional[0].shape
+
+        print(f"Calculated input shapes: {input_shapes}")
+        return input_shapes
+
+    def get_unconditionals(self):
+        """Get unconditional inputs for all conditions."""
+        unconditionals = []
+        for cond in self.conditions:
+            uncond = cond.get_unconditional()
+            unconditionals.append(uncond)
+        return unconditionals
+
+    def process_conditioning(self, batch_data, uncond_mask: Optional[jnp.ndarray] = None):
+        """Process the conditioning data."""
+        results = []
+
+        for cond in self.conditions:
+            cond_embeddings = cond(batch_data)
+            if uncond_mask is not None:
+                assert len(uncond_mask) == len(cond_embeddings), "Unconditional mask length must match the batch size."
+                uncond_embedding = cond.get_unconditional()
+
+                # Reshape uncond_mask to be broadcastable with the conditioning embeddings
+                # If cond_embeddings has shape (B, T, D), reshape uncond_mask to (B, 1, 1)
+                broadcast_shape = [len(uncond_mask)] + [1] * (cond_embeddings.ndim - 1)
+                reshaped_mask = jnp.reshape(uncond_mask, broadcast_shape)
+
+                # Repeat uncond_embedding to match batch size
+                batch_size = len(cond_embeddings)
+                repeated_uncond = jnp.repeat(uncond_embedding, batch_size, axis=0)
+
+                # Apply unconditional embedding based on the mask
+                cond_embeddings = jnp.where(reshaped_mask, repeated_uncond, cond_embeddings)
+
+            results.append(cond_embeddings)
+        return results
+
+    def serialize(self):
+        """Serialize the configuration."""
+        serialized_config = {
+            "sample_data_key": self.sample_data_key,
+            "sample_data_shape": self.sample_data_shape,
+            "conditions": [cond.serialize() for cond in self.conditions],
+        }
+        return serialized_config
+
+    @staticmethod
+    def deserialize(serialized_config):
+        """Deserialize the configuration."""
+        sample_data_key = serialized_config["sample_data_key"]
+        sample_data_shape = tuple(serialized_config["sample_data_shape"])
+        conditions = serialized_config["conditions"]
+
+        # Deserialize each condition
+        deserialized_conditions = []
+        for cond in conditions:
+            deserialized_conditions.append(ConditionalInputConfig.deserialize(cond))
+
+        return DiffusionInputConfig(
+            sample_data_key=sample_data_key,
+            sample_data_shape=sample_data_shape,
+            conditions=deserialized_conditions,
+        )
```
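The new `flaxdiff.inputs` package ties conditioning into the trainers: each `ConditionalInputConfig` caches an unconditional embedding at construction, and `DiffusionInputConfig.process_conditioning` substitutes that cached embedding wherever `uncond_mask` is true, which is the usual classifier-free-guidance dropout. A minimal usage sketch against the API added above; the batch contents are invented, and the `(77, 768)` shape comment assumes the default CLIP-L text encoder:

```python
import jax.numpy as jnp
from flaxdiff.inputs import ConditionalInputConfig, DiffusionInputConfig
from flaxdiff.inputs.encoders import CLIPTextEncoder

# One text condition backed by CLIP; the empty-string unconditional
# embedding is computed once in __post_init__ and cached.
text_cond = ConditionalInputConfig(
    encoder=CLIPTextEncoder.from_modelname(backend="jax"),
    unconditional_input="",
)

# A 256x256 RGB image diffusion task with a single condition.
input_config = DiffusionInputConfig(
    sample_data_key="image",
    sample_data_shape=(256, 256, 3),
    conditions=[text_cond],
)

batch = {
    "image": jnp.zeros((2, 256, 256, 3)),
    "text": ["a photo of a cat", "a photo of a dog"],  # key comes from encoder.key
}

# Drop conditioning for the first sample only (CFG-style masking).
uncond_mask = jnp.array([True, False])
(text_embeddings,) = input_config.process_conditioning(batch, uncond_mask=uncond_mask)

# Shapes for initializing the denoiser,
# e.g. {'x': (256, 256, 3), 'temb': (), 'text': (77, 768)}.
input_shapes = input_config.get_input_shapes()
```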
flaxdiff/inputs/encoders.py
ADDED
```diff
@@ -0,0 +1,98 @@
+import jax.numpy as jnp
+import flax.linen as nn
+from typing import Callable
+from dataclasses import dataclass
+from abc import ABC, abstractmethod
+
+@dataclass
+class ConditioningEncoder(ABC):
+    model: nn.Module
+    tokenizer: Callable
+
+    @property
+    def key(self):
+        name = self.tokenizer.__name__
+        # Remove the 'Encoder' suffix from the name and lowercase it
+        if name.endswith("Encoder"):
+            name = name[:-7].lower()
+        return name
+
+    def __call__(self, data):
+        tokens = self.tokenize(data)
+        outputs = self.encode_from_tokens(tokens)
+        return outputs
+
+    def encode_from_tokens(self, tokens):
+        outputs = self.model(input_ids=tokens['input_ids'],
+                             attention_mask=tokens['attention_mask'])
+        last_hidden_state = outputs.last_hidden_state
+        return last_hidden_state
+
+    def tokenize(self, data):
+        tokens = self.tokenizer(data, padding="max_length",
+                                max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np")
+        return tokens
+
+    @abstractmethod
+    def serialize(self):
+        """Serialize the encoder configuration."""
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def deserialize(serialized_config):
+        """Deserialize the encoder configuration."""
+        pass
+
+@dataclass
+class TextEncoder(ConditioningEncoder):
+    """Text Encoder."""
+    @property
+    def key(self):
+        return "text"
+
+@dataclass
+class CLIPTextEncoder(TextEncoder):
+    """CLIP Text Encoder."""
+    modelname: str
+    backend: str
+
+    @staticmethod
+    def from_modelname(modelname: str = "openai/clip-vit-large-patch14", backend: str = "jax"):
+        from transformers import (
+            CLIPTextModel,
+            FlaxCLIPTextModel,
+            AutoTokenizer,
+        )
+        modelname = "openai/clip-vit-large-patch14"
+        if backend == "jax":
+            model = FlaxCLIPTextModel.from_pretrained(
+                modelname, dtype=jnp.bfloat16)
+        else:
+            model = CLIPTextModel.from_pretrained(modelname)
+        tokenizer = AutoTokenizer.from_pretrained(modelname, dtype=jnp.float16)
+        return CLIPTextEncoder(
+            model=model,
+            tokenizer=tokenizer,
+            modelname=modelname,
+            backend=backend
+        )
+
+    def serialize(self):
+        """Serialize the encoder configuration."""
+        serialized_config = {
+            "modelname": self.modelname,
+            "backend": self.backend,
+        }
+        return serialized_config
+
+    @staticmethod
+    def deserialize(serialized_config):
+        """Deserialize the encoder configuration."""
+        modelname = serialized_config["modelname"]
+        backend = serialized_config["backend"]
+        return CLIPTextEncoder.from_modelname(modelname=modelname, backend=backend)
+
+CONDITIONAL_ENCODERS_REGISTRY = {
+    "text": CLIPTextEncoder,
+}
```
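`ConditioningEncoder` pairs a Hugging Face model with its tokenizer and returns the model's `last_hidden_state` as the conditioning signal. One caveat visible in the code above: `from_modelname` reassigns `modelname` to the CLIP-L default before loading, so a custom model name is currently ignored, and the `dtype` kwarg passed to `AutoTokenizer.from_pretrained` has no effect on a tokenizer. A usage sketch under those assumptions (hub access required; the printed shape assumes openai/clip-vit-large-patch14):

```python
from flaxdiff.inputs.encoders import CLIPTextEncoder

# Load the Flax CLIP text tower and its tokenizer from the Hugging Face hub.
encoder = CLIPTextEncoder.from_modelname(backend="jax")

# Tokenize and encode in one call; returns the model's last_hidden_state.
embeddings = encoder(["a watercolor painting of a fox"])
print(embeddings.shape)  # (1, 77, 768) for CLIP-L/14

# serialize() stores only the lightweight config; deserialize()
# re-fetches the weights via from_modelname.
config = encoder.serialize()  # {'modelname': 'openai/clip-vit-large-patch14', 'backend': 'jax'}
restored = CLIPTextEncoder.deserialize(config)
```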
flaxdiff/models/__init__.py
CHANGED
```diff
@@ -1 +1,2 @@
-from .simple_unet import *
+from .simple_unet import *
+# from .video_unet import FlaxUNet3DConditionModel, BCHWModelWrapper, FlaxTemporalConvLayer
```
flaxdiff/models/attention.py
CHANGED
```diff
@@ -23,7 +23,7 @@ class EfficientAttention(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     use_bias: bool = True
-    kernel_init: Callable = kernel_init(1.0)
+    # kernel_init: Callable = kernel_init(1.0)
     force_fp32_for_softmax: bool = True
 
     def setup(self):
@@ -34,15 +34,21 @@ class EfficientAttention(nn.Module):
             self.heads * self.dim_head,
             precision=self.precision,
             use_bias=self.use_bias,
-            kernel_init=self.kernel_init,
+            # kernel_init=self.kernel_init,
             dtype=self.dtype
         )
         self.query = dense(name="to_q")
         self.key = dense(name="to_k")
         self.value = dense(name="to_v")
 
-        self.proj_attn = nn.DenseGeneral(
-
+        self.proj_attn = nn.DenseGeneral(
+            self.query_dim,
+            use_bias=False,
+            precision=self.precision,
+            # kernel_init=self.kernel_init,
+            dtype=self.dtype,
+            name="to_out_0"
+        )
         # self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16)
 
     def _reshape_tensor_to_head_dim(self, tensor):
@@ -115,7 +121,7 @@ class NormalAttention(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     use_bias: bool = True
-    kernel_init: Callable = kernel_init(1.0)
+    # kernel_init: Callable = kernel_init(1.0)
     force_fp32_for_softmax: bool = True
 
     def setup(self):
@@ -126,7 +132,7 @@ class NormalAttention(nn.Module):
             axis=-1,
             precision=self.precision,
             use_bias=self.use_bias,
-            kernel_init=self.kernel_init,
+            # kernel_init=self.kernel_init,
             dtype=self.dtype
         )
         self.query = dense(name="to_q")
@@ -140,7 +146,7 @@ class NormalAttention(nn.Module):
             use_bias=self.use_bias,
             dtype=self.dtype,
             name="to_out_0",
-            kernel_init=self.kernel_init
+            # kernel_init=self.kernel_init
             # kernel_init=jax.nn.initializers.xavier_uniform()
         )
 
@@ -236,7 +242,7 @@ class BasicTransformerBlock(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     use_bias: bool = True
-    kernel_init: Callable = kernel_init(1.0)
+    # kernel_init: Callable = kernel_init(1.0)
     use_flash_attention:bool = False
     use_cross_only:bool = False
     only_pure_attention:bool = False
@@ -256,7 +262,7 @@ class BasicTransformerBlock(nn.Module):
             precision=self.precision,
             use_bias=self.use_bias,
             dtype=self.dtype,
-            kernel_init=self.kernel_init,
+            # kernel_init=self.kernel_init,
             force_fp32_for_softmax=self.force_fp32_for_softmax
         )
         self.attention2 = attenBlock(
@@ -267,7 +273,7 @@ class BasicTransformerBlock(nn.Module):
             precision=self.precision,
             use_bias=self.use_bias,
             dtype=self.dtype,
-            kernel_init=self.kernel_init,
+            # kernel_init=self.kernel_init,
             force_fp32_for_softmax=self.force_fp32_for_softmax
         )
 
@@ -303,7 +309,7 @@ class TransformerBlock(nn.Module):
    use_self_and_cross:bool = True
    only_pure_attention:bool = False
    force_fp32_for_softmax: bool = True
-   kernel_init: Callable = kernel_init(1.0)
+   # kernel_init: Callable = kernel_init(1.0)
    norm_inputs: bool = True
    explicitly_add_residual: bool = True
 
@@ -317,12 +323,12 @@ class TransformerBlock(nn.Module):
         if self.use_linear_attention:
             projected_x = nn.Dense(features=inner_dim,
                                    use_bias=False, precision=self.precision,
-                                   kernel_init=self.kernel_init,
+                                   # kernel_init=self.kernel_init,
                                    dtype=self.dtype, name=f'project_in')(x)
         else:
             projected_x = nn.Conv(
                 features=inner_dim, kernel_size=(1, 1),
-                kernel_init=self.kernel_init,
+                # kernel_init=self.kernel_init,
                 strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                 precision=self.precision, name=f'project_in_conv',
             )(x)
@@ -344,19 +350,19 @@ class TransformerBlock(nn.Module):
             use_cross_only=(not self.use_self_and_cross),
             only_pure_attention=self.only_pure_attention,
             force_fp32_for_softmax=self.force_fp32_for_softmax,
-            kernel_init=self.kernel_init
+            # kernel_init=self.kernel_init
         )(projected_x, context)
 
         if self.use_projection == True:
             if self.use_linear_attention:
                 projected_x = nn.Dense(features=C, precision=self.precision,
                                        dtype=self.dtype, use_bias=False,
-                                       kernel_init=self.kernel_init,
+                                       # kernel_init=self.kernel_init,
                                        name=f'project_out')(projected_x)
             else:
                 projected_x = nn.Conv(
                     features=C, kernel_size=(1, 1),
-                    kernel_init=self.kernel_init,
+                    # kernel_init=self.kernel_init,
                     strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                     precision=self.precision, name=f'project_out_conv',
                 )(projected_x)
```
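Every change in this file comments out the custom `kernel_init` default and each place it was forwarded, so the affected `nn.Dense`, `nn.DenseGeneral`, and `nn.Conv` layers now fall back to Flax's built-in default, `lecun_normal()`. A standalone sketch of what that swap means; the variance-scaling initializer below is a stand-in for flaxdiff's `kernel_init(1.0)` helper from `models.common`, whose exact scaling is not shown in this diff:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((1, 8))

# Before: an explicit initializer was forwarded to every projection.
explicit = nn.Dense(
    16,
    kernel_init=nn.initializers.variance_scaling(1.0, mode="fan_in", distribution="normal"),
)

# After: with kernel_init commented out, nn.Dense uses its default lecun_normal().
default = nn.Dense(16)

# Identical parameter structure either way; only the initial weight statistics differ.
p_explicit = explicit.init(jax.random.PRNGKey(0), x)
p_default = default.init(jax.random.PRNGKey(0), x)
print(jax.tree_util.tree_map(jnp.shape, p_default))
```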
flaxdiff/models/autoencoder/autoencoder.py
CHANGED
```diff
@@ -1,19 +1,151 @@
 import jax
 import jax.numpy as jnp
 from flax import linen as nn
-from typing import Dict, Callable, Sequence, Any, Union
+from typing import Dict, Callable, Sequence, Any, Union, Optional
 import einops
 from ..common import kernel_init, ConvLayer, Upsample, Downsample, PixelShuffle
+from dataclasses import dataclass
+from abc import ABC, abstractmethod
 
-
-class AutoEncoder():
-
+@dataclass
+class AutoEncoder(ABC):
+    """Base class for autoencoder models with video support.
+
+    This class defines the interface for autoencoders and provides
+    video handling functionality, allowing child classes to focus
+    on implementing the core encoding/decoding for individual frames.
+    """
+    @abstractmethod
+    def __encode__(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
+        """Abstract method for encoding a batch of images.
+
+        Child classes must implement this method to perform the actual encoding.
+
+        Args:
+            x: Input tensor of shape [B, H, W, C] (batch of images)
+            **kwargs: Additional arguments for the encoding process
+
+        Returns:
+            Encoded latent representation
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def __decode__(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
+        """Abstract method for decoding a batch of latents.
+
+        Child classes must implement this method to perform the actual decoding.
+
+        Args:
+            z: Latent tensor of shape [B, h, w, c] (encoded representation)
+            **kwargs: Additional arguments for the decoding process
+
+        Returns:
+            Decoded images
+        """
         raise NotImplementedError
 
-    def
+    def encode(self, x: jnp.ndarray, key: Optional[jax.random.PRNGKey] = None, **kwargs) -> jnp.ndarray:
+        """Encode input data, with special handling for video data.
+
+        This method handles both standard image batches and video data (5D tensors).
+        For videos, it reshapes the input, processes each frame, and then restores
+        the temporal dimension.
+
+        Args:
+            x: Input tensor, either [B, H, W, C] for images or [B, T, H, W, C] for videos
+            key: Optional random key for stochastic encoding
+            **kwargs: Additional arguments passed to __encode__
+
+        Returns:
+            Encoded representation with the same batch and temporal dimensions as input
+        """
+        # Check for video data (5D tensor)
+        is_video = len(x.shape) == 5
+
+        if is_video:
+            # Extract dimensions for reshaping
+            batch_size, seq_len, height, width, channels = x.shape
+
+            # Reshape to [B*T, H, W, C] to process as regular images
+            x_reshaped = x.reshape(-1, height, width, channels)
+
+            # Encode all frames
+            latent = self.__encode__(x_reshaped, key=key, **kwargs)
+
+            # Reshape back to include temporal dimension [B, T, h, w, c]
+            latent_shape = latent.shape
+            return latent.reshape(batch_size, seq_len, *latent_shape[1:])
+        else:
+            # Standard image processing
+            return self.__encode__(x, key=key, **kwargs)
+
+    def decode(self, z: jnp.ndarray, key: Optional[jax.random.PRNGKey] = None, **kwargs) -> jnp.ndarray:
+        """Decode latent representations, with special handling for video data.
+
+        This method handles both standard image latents and video latents (5D tensors).
+        For videos, it reshapes the input, processes each frame, and then restores
+        the temporal dimension.
+
+        Args:
+            z: Latent tensor, either [B, h, w, c] for images or [B, T, h, w, c] for videos
+            key: Optional random key for stochastic decoding
+            **kwargs: Additional arguments passed to __decode__
+
+        Returns:
+            Decoded output with the same batch and temporal dimensions as input
+        """
+        # Check for video data (5D tensor)
+        is_video = len(z.shape) == 5
+
+        if is_video:
+            # Extract dimensions for reshaping
+            batch_size, seq_len, height, width, channels = z.shape
+
+            # Reshape to [B*T, h, w, c] to process as regular latents
+            z_reshaped = z.reshape(-1, height, width, channels)
+
+            # Decode all frames
+            decoded = self.__decode__(z_reshaped, key=key, **kwargs)
+
+            # Reshape back to include temporal dimension [B, T, H, W, C]
+            decoded_shape = decoded.shape
+            return decoded.reshape(batch_size, seq_len, *decoded_shape[1:])
+        else:
+            # Standard latent processing
+            return self.__decode__(z, key=key, **kwargs)
+
+    def __call__(self, x: jnp.ndarray, key: Optional[jax.random.PRNGKey] = None, **kwargs):
+        """Encode and then decode the input (autoencoder).
+
+        Args:
+            x: Input tensor, either [B, H, W, C] for images or [B, T, H, W, C] for videos
+            key: Optional random key for stochastic encoding/decoding
+            **kwargs: Additional arguments for encoding and decoding
+
+        Returns:
+            Reconstructed output with the same dimensions as input
+        """
+        if key is not None:
+            encode_key, decode_key = jax.random.split(key)
+        else:
+            encode_key = decode_key = None
+
+        # Encode then decode
+        z = self.encode(x, key=encode_key, **kwargs)
+        return self.decode(z, key=decode_key, **kwargs)
+
+    @property
+    def spatial_scale(self) -> int:
+        """Get the spatial scale factor between input and latent spaces."""
+        return getattr(self, "_spatial_scale", None)
+
+    @property
+    def name(self) -> str:
+        """Get the name of the autoencoder model."""
         raise NotImplementedError
 
-
-
-
-
+    @abstractmethod
+    def serialize(self) -> Dict[str, Any]:
+        """Serialize the model parameters and configuration."""
+        raise NotImplementedError
```
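The rewritten base class folds the time axis of a 5D video batch into the batch axis, runs the per-frame `__encode__`/`__decode__` that subclasses provide, and unfolds the result. A toy subclass to illustrate the contract; the 2x subsample/repeat "codec" below is invented purely for demonstration and is not part of flaxdiff:

```python
import jax
import jax.numpy as jnp
from flaxdiff.models.autoencoder import AutoEncoder

class ToyAutoEncoder(AutoEncoder):
    """Illustrative only: 2x subsampling as 'encode', nearest-neighbor repeat as 'decode'."""

    def __encode__(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
        return x[:, ::2, ::2, :]  # [B, H, W, C] -> [B, H/2, W/2, C]

    def __decode__(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
        return jnp.repeat(jnp.repeat(z, 2, axis=1), 2, axis=2)

    def serialize(self):
        return {}

video = jnp.zeros((2, 8, 64, 64, 3))  # [B, T, H, W, C]
ae = ToyAutoEncoder()

# The base class reshapes to [B*T, H, W, C], encodes, and restores [B, T, ...].
latents = ae.encode(video)
assert latents.shape == (2, 8, 32, 32, 3)

# __call__ splits the key and runs encode + decode back to the input shape.
reconstruction = ae(video, key=jax.random.PRNGKey(0))
assert reconstruction.shape == video.shape
```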