flaxdiff 0.1.38.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. flaxdiff/data/__init__.py +5 -1
  2. flaxdiff/data/benchmark_decord.py +443 -0
  3. flaxdiff/data/dataloaders.py +608 -0
  4. flaxdiff/data/dataset_map.py +61 -6
  5. flaxdiff/data/online_loader.py +779 -150
  6. flaxdiff/data/sources/audio_utils.py +142 -0
  7. flaxdiff/data/sources/av_example.py +125 -0
  8. flaxdiff/data/sources/av_utils.py +590 -0
  9. flaxdiff/data/sources/base.py +129 -0
  10. flaxdiff/data/sources/images.py +309 -0
  11. flaxdiff/data/sources/utils.py +158 -0
  12. flaxdiff/data/sources/videos.py +250 -0
  13. flaxdiff/data/sources/voxceleb2.py +412 -0
  14. flaxdiff/inference/__init__.py +0 -0
  15. flaxdiff/inference/pipeline.py +260 -0
  16. flaxdiff/inference/utils.py +320 -0
  17. flaxdiff/inputs/__init__.py +173 -0
  18. flaxdiff/inputs/encoders.py +98 -0
  19. flaxdiff/models/__init__.py +2 -1
  20. flaxdiff/models/autoencoder/autoencoder.py +141 -9
  21. flaxdiff/models/autoencoder/diffusers.py +88 -25
  22. flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
  23. flaxdiff/models/simple_unet.py +5 -5
  24. flaxdiff/models/simple_vit.py +1 -1
  25. flaxdiff/models/unet_3d.py +446 -0
  26. flaxdiff/models/unet_3d_blocks.py +505 -0
  27. flaxdiff/samplers/common.py +358 -96
  28. flaxdiff/samplers/ddim.py +44 -5
  29. flaxdiff/schedulers/karras.py +20 -12
  30. flaxdiff/trainer/__init__.py +2 -1
  31. flaxdiff/trainer/autoencoder_trainer.py +1 -2
  32. flaxdiff/trainer/diffusion_trainer.py +33 -27
  33. flaxdiff/trainer/general_diffusion_trainer.py +583 -0
  34. flaxdiff/trainer/simple_trainer.py +48 -31
  35. flaxdiff/utils.py +128 -57
  36. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
  37. flaxdiff-0.2.0.dist-info/RECORD +64 -0
  38. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
  39. flaxdiff/data/datasets.py +0 -169
  40. flaxdiff/data/sources/gcs.py +0 -81
  41. flaxdiff/data/sources/tfds.py +0 -79
  42. flaxdiff/trainer/video_diffusion_trainer.py +0 -62
  43. flaxdiff-0.1.38.1.dist-info/RECORD +0 -50
  44. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
flaxdiff/inputs/__init__.py
@@ -0,0 +1,173 @@
+import jax
+import jax.numpy as jnp
+import flax.struct as struct
+import flax.linen as nn
+from typing import Any, Callable, Dict, List, Optional, Tuple
+from dataclasses import dataclass
+from functools import partial
+import numpy as np
+from jax.sharding import Mesh, PartitionSpec as P
+from abc import ABC, abstractmethod
+
+from flaxdiff.models.autoencoder import AutoEncoder
+from .encoders import *
+
+@dataclass
+class ConditionalInputConfig:
+    """Class representing a conditional input for the model."""
+    encoder: ConditioningEncoder
+    conditioning_data_key: str = None  # Key in the batch for this conditioning input
+    pretokenized: bool = False
+    unconditional_input: Any = None
+    model_key_override: Optional[str] = None  # Optional key override for the model
+
+    __uncond_cache__ = None  # Cache for unconditional input
+
+    def __post_init__(self):
+        if self.unconditional_input is not None:
+            uncond = self.encoder([self.unconditional_input])
+        else:
+            uncond = self.encoder([""])  # Default empty text
+        self.__uncond_cache__ = uncond  # Cache the unconditional input
+
+    def __call__(self, batch_data):
+        """Process batch data to produce conditioning."""
+        key = self.conditioning_data_key if self.conditioning_data_key else self.encoder.key
+        if self.pretokenized:
+            return self.encoder.encode_from_tokens(batch_data[key])
+        return self.encoder(batch_data[key])
+
+    def get_unconditional(self):
+        """Get unconditional version of this input."""
+        return self.__uncond_cache__
+
+    def serialize(self):
+        """Serialize the configuration."""
+        serialized_config = {
+            "encoder": self.encoder.serialize(),
+            "encoder_key": self.encoder.key,
+            "conditioning_data_key": self.conditioning_data_key,
+            "unconditional_input": self.unconditional_input,
+            "model_key_override": self.model_key_override,
+        }
+        return serialized_config
+
+    @staticmethod
+    def deserialize(serialized_config):
+        """Deserialize the configuration."""
+        encoder_key = serialized_config["encoder_key"]
+        encoder_class = CONDITIONAL_ENCODERS_REGISTRY.get(encoder_key)
+        if encoder_class is None:
+            raise ValueError(f"Unknown encoder type: {encoder_key}")
+
+        # Create the encoder instance
+        encoder = encoder_class.deserialize(serialized_config["encoder"])
+        # Deserialize the rest of the configuration
+        conditioning_data_key = serialized_config.get("conditioning_data_key")
+        unconditional_input = serialized_config.get("unconditional_input")
+        model_key_override = serialized_config.get("model_key_override")
+        return ConditionalInputConfig(
+            encoder=encoder,
+            conditioning_data_key=conditioning_data_key,
+            unconditional_input=unconditional_input,
+            model_key_override=model_key_override,
+        )
+
+@dataclass
+class DiffusionInputConfig:
+    """Configuration for the input data."""
+    sample_data_key: str  # Key in the batch for the sample data
+    sample_data_shape: Tuple[int, ...]
+    conditions: List[ConditionalInputConfig]
+
+    def get_input_shapes(
+        self,
+        autoencoder: AutoEncoder = None,
+        sample_model_key:str = 'x',
+        time_embeddings_model_key:str = 'temb',
+    ) -> Dict[str, Tuple[int, ...]]:
+        """Get the shapes of the input data."""
+        if len(self.sample_data_shape) == 3:
+            H, W, C = self.sample_data_shape
+        elif len(self.sample_data_shape) == 4:
+            T, H, W, C = self.sample_data_shape
+        else:
+            raise ValueError(f"Unsupported shape for sample data {self.sample_data_shape}")
+        if autoencoder is not None:
+            downscale_factor = autoencoder.downscale_factor
+            H = H // downscale_factor
+            W = W // downscale_factor
+            C = autoencoder.latent_channels
+
+        input_shapes = {
+            sample_model_key: (H, W, C),
+            time_embeddings_model_key: (),
+        }
+        for cond in self.conditions:
+            # Get the shape of the conditioning data by calling the get_unconditional method
+            unconditional = cond.get_unconditional()
+            key = cond.model_key_override if cond.model_key_override else cond.encoder.key
+            input_shapes[key] = unconditional[0].shape
+
+        print(f"Calculated input shapes: {input_shapes}")
+        return input_shapes
+
+    def get_unconditionals(self):
+        """Get unconditional inputs for all conditions."""
+        unconditionals = []
+        for cond in self.conditions:
+            uncond = cond.get_unconditional()
+            unconditionals.append(uncond)
+        return unconditionals
+
+    def process_conditioning(self, batch_data, uncond_mask: Optional[jnp.ndarray] = None):
+        """Process the conditioning data."""
+        results = []
+
+        for cond in self.conditions:
+            cond_embeddings = cond(batch_data)
+            if uncond_mask is not None:
+                assert len(uncond_mask) == len(cond_embeddings), "Unconditional mask length must match the batch size."
+                uncond_embedding = cond.get_unconditional()
+
+                # Reshape uncond_mask to be broadcastable with the conditioning embeddings
+                # If cond_embeddings has shape (B, T, D), reshape uncond_mask to (B, 1, 1)
+                broadcast_shape = [len(uncond_mask)] + [1] * (cond_embeddings.ndim - 1)
+                reshaped_mask = jnp.reshape(uncond_mask, broadcast_shape)
+
+                # Repeat uncond_embedding to match batch size
+                batch_size = len(cond_embeddings)
+                repeated_uncond = jnp.repeat(uncond_embedding, batch_size, axis=0)
+
+                # Apply unconditional embedding based on the mask
+                cond_embeddings = jnp.where(reshaped_mask, repeated_uncond, cond_embeddings)
+
+            results.append(cond_embeddings)
+        return results
+
+    def serialize(self):
+        """Serialize the configuration."""
+        serialized_config = {
+            "sample_data_key": self.sample_data_key,
+            "sample_data_shape": self.sample_data_shape,
+            "conditions": [cond.serialize() for cond in self.conditions],
+        }
+        return serialized_config
+
+    @staticmethod
+    def deserialize(serialized_config):
+        """Deserialize the configuration."""
+        sample_data_key = serialized_config["sample_data_key"]
+        sample_data_shape = tuple(serialized_config["sample_data_shape"])
+        conditions = serialized_config["conditions"]
+
+        # Deserialize each condition
+        deserialized_conditions = []
+        for cond in conditions:
+            deserialized_conditions.append(ConditionalInputConfig.deserialize(cond))
+
+        return DiffusionInputConfig(
+            sample_data_key=sample_data_key,
+            sample_data_shape=sample_data_shape,
+            conditions=deserialized_conditions,
+        )
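
Taken together, `ConditionalInputConfig` wraps one encoder plus its cached classifier-free-guidance fallback, and `DiffusionInputConfig` describes the full model input. A minimal usage sketch, assuming a text-conditioned image batch with hypothetical `image`/`caption` keys:

```python
import jax.numpy as jnp
from flaxdiff.inputs import ConditionalInputConfig, DiffusionInputConfig, CLIPTextEncoder

# One text condition backed by CLIP; the empty-string unconditional
# embedding is computed and cached by __post_init__.
text_cond = ConditionalInputConfig(
    encoder=CLIPTextEncoder.from_modelname(backend="jax"),
    conditioning_data_key="caption",  # hypothetical batch key
)

input_config = DiffusionInputConfig(
    sample_data_key="image",          # hypothetical batch key
    sample_data_shape=(256, 256, 3),
    conditions=[text_cond],
)

batch = {"caption": ["a photo of a cat", "a photo of a dog"]}
# Entries marked True are swapped for the cached unconditional embedding,
# the usual classifier-free-guidance dropout during training.
uncond_mask = jnp.array([False, True])
(text_emb,) = input_config.process_conditioning(batch, uncond_mask=uncond_mask)
```

Note that `get_input_shapes` derives the latent sample shape from the autoencoder's `downscale_factor` and `latent_channels`, which is why the reworked `AutoEncoder` base class below exposes both as properties.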
flaxdiff/inputs/encoders.py
@@ -0,0 +1,98 @@
+import jax.numpy as jnp
+import flax.linen as nn
+from typing import Callable
+from dataclasses import dataclass
+from abc import ABC, abstractmethod
+
+@dataclass
+class ConditioningEncoder(ABC):
+    model: nn.Module
+    tokenizer: Callable
+
+    @property
+    def key(self):
+        name = self.tokenizer.__name__
+        # Remove the 'Encoder' suffix from the name and lowercase it
+        if name.endswith("Encoder"):
+            name = name[:-7].lower()
+        return name
+
+    def __call__(self, data):
+        tokens = self.tokenize(data)
+        outputs = self.encode_from_tokens(tokens)
+        return outputs
+
+    def encode_from_tokens(self, tokens):
+        outputs = self.model(input_ids=tokens['input_ids'],
+                             attention_mask=tokens['attention_mask'])
+        last_hidden_state = outputs.last_hidden_state
+        return last_hidden_state
+
+    def tokenize(self, data):
+        tokens = self.tokenizer(data, padding="max_length",
+                                max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np")
+        return tokens
+
+    @abstractmethod
+    def serialize(self):
+        """Serialize the encoder configuration."""
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def deserialize(serialized_config):
+        """Deserialize the encoder configuration."""
+        pass
+
+@dataclass
+class TextEncoder(ConditioningEncoder):
+    """Text Encoder."""
+    @property
+    def key(self):
+        return "text"
+
+@dataclass
+class CLIPTextEncoder(TextEncoder):
+    """CLIP Text Encoder."""
+    modelname: str
+    backend: str
+
+    @staticmethod
+    def from_modelname(modelname: str = "openai/clip-vit-large-patch14", backend: str="jax"):
+        from transformers import (
+            CLIPTextModel,
+            FlaxCLIPTextModel,
+            AutoTokenizer,
+        )
+        modelname = "openai/clip-vit-large-patch14"
+        if backend == "jax":
+            model = FlaxCLIPTextModel.from_pretrained(
+                modelname, dtype=jnp.bfloat16)
+        else:
+            model = CLIPTextModel.from_pretrained(modelname)
+        tokenizer = AutoTokenizer.from_pretrained(modelname, dtype=jnp.float16)
+        return CLIPTextEncoder(
+            model=model,
+            tokenizer=tokenizer,
+            modelname=modelname,
+            backend=backend
+        )
+
+    def serialize(self):
+        """Serialize the encoder configuration."""
+        serialized_config = {
+            "modelname": self.modelname,
+            "backend": self.backend,
+        }
+        return serialized_config
+
+    @staticmethod
+    def deserialize(serialized_config):
+        """Deserialize the encoder configuration."""
+        modelname = serialized_config["modelname"]
+        backend = serialized_config["backend"]
+        return CLIPTextEncoder.from_modelname(modelname=modelname, backend=backend)
+
+CONDITIONAL_ENCODERS_REGISTRY = {
+    "text": CLIPTextEncoder,
+}
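
The registry plus the `serialize`/`deserialize` pair let a checkpoint rebuild its conditioning stack from plain dicts. A round-trip sketch (note that `from_modelname` re-assigns `modelname` to the hard-coded `openai/clip-vit-large-patch14` before loading, so other checkpoints are effectively ignored in this release):

```python
from flaxdiff.inputs.encoders import CLIPTextEncoder, CONDITIONAL_ENCODERS_REGISTRY

encoder = CLIPTextEncoder.from_modelname(backend="jax")
config = encoder.serialize()  # {"modelname": "openai/clip-vit-large-patch14", "backend": "jax"}

# Mirror what ConditionalInputConfig.deserialize does: look the class up by
# its key, then rebuild it from the serialized config.
encoder_cls = CONDITIONAL_ENCODERS_REGISTRY[encoder.key]  # "text" -> CLIPTextEncoder
restored = encoder_cls.deserialize(config)

embeddings = restored(["a photo of a cat"])  # last_hidden_state, (1, 77, 768) for ViT-L/14
```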
flaxdiff/models/__init__.py
@@ -1 +1,2 @@
-from .simple_unet import *
+from .simple_unet import *
+# from .video_unet import FlaxUNet3DConditionModel, BCHWModelWrapper, FlaxTemporalConvLayer
flaxdiff/models/autoencoder/autoencoder.py
@@ -1,19 +1,151 @@
 import jax
 import jax.numpy as jnp
 from flax import linen as nn
-from typing import Dict, Callable, Sequence, Any, Union
+from typing import Dict, Callable, Sequence, Any, Union, Optional
 import einops
 from ..common import kernel_init, ConvLayer, Upsample, Downsample, PixelShuffle
+from dataclasses import dataclass
+from abc import ABC, abstractmethod
 
-
-class AutoEncoder():
-    def encode(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
+@dataclass
+class AutoEncoder(ABC):
+    """Base class for autoencoder models with video support.
+
+    This class defines the interface for autoencoders and provides
+    video handling functionality, allowing child classes to focus
+    on implementing the core encoding/decoding for individual frames.
+    """
+    @abstractmethod
+    def __encode__(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
+        """Abstract method for encoding a batch of images.
+
+        Child classes must implement this method to perform the actual encoding.
+
+        Args:
+            x: Input tensor of shape [B, H, W, C] (batch of images)
+            **kwargs: Additional arguments for the encoding process
+
+        Returns:
+            Encoded latent representation
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def __decode__(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
+        """Abstract method for decoding a batch of latents.
+
+        Child classes must implement this method to perform the actual decoding.
+
+        Args:
+            z: Latent tensor of shape [B, h, w, c] (encoded representation)
+            **kwargs: Additional arguments for the decoding process
+
+        Returns:
+            Decoded images
+        """
         raise NotImplementedError
 
-    def decode(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
+    def encode(self, x: jnp.ndarray, key: Optional[jax.random.PRNGKey] = None, **kwargs) -> jnp.ndarray:
+        """Encode input data, with special handling for video data.
+
+        This method handles both standard image batches and video data (5D tensors).
+        For videos, it reshapes the input, processes each frame, and then restores
+        the temporal dimension.
+
+        Args:
+            x: Input tensor, either [B, H, W, C] for images or [B, T, H, W, C] for videos
+            key: Optional random key for stochastic encoding
+            **kwargs: Additional arguments passed to __encode__
+
+        Returns:
+            Encoded representation with the same batch and temporal dimensions as input
+        """
+        # Check for video data (5D tensor)
+        is_video = len(x.shape) == 5
+
+        if is_video:
+            # Extract dimensions for reshaping
+            batch_size, seq_len, height, width, channels = x.shape
+
+            # Reshape to [B*T, H, W, C] to process as regular images
+            x_reshaped = x.reshape(-1, height, width, channels)
+
+            # Encode all frames
+            latent = self.__encode__(x_reshaped, key=key, **kwargs)
+
+            # Reshape back to include temporal dimension [B, T, h, w, c]
+            latent_shape = latent.shape
+            return latent.reshape(batch_size, seq_len, *latent_shape[1:])
+        else:
+            # Standard image processing
+            return self.__encode__(x, key=key, **kwargs)
+
+    def decode(self, z: jnp.ndarray, key: Optional[jax.random.PRNGKey] = None, **kwargs) -> jnp.ndarray:
+        """Decode latent representations, with special handling for video data.
+
+        This method handles both standard image latents and video latents (5D tensors).
+        For videos, it reshapes the input, processes each frame, and then restores
+        the temporal dimension.
+
+        Args:
+            z: Latent tensor, either [B, h, w, c] for images or [B, T, h, w, c] for videos
+            key: Optional random key for stochastic decoding
+            **kwargs: Additional arguments passed to __decode__
+
+        Returns:
+            Decoded output with the same batch and temporal dimensions as input
+        """
+        # Check for video data (5D tensor)
+        is_video = len(z.shape) == 5
+
+        if is_video:
+            # Extract dimensions for reshaping
+            batch_size, seq_len, height, width, channels = z.shape
+
+            # Reshape to [B*T, h, w, c] to process as regular latents
+            z_reshaped = z.reshape(-1, height, width, channels)
+
+            # Decode all frames
+            decoded = self.__decode__(z_reshaped, key=key, **kwargs)
+
+            # Reshape back to include temporal dimension [B, T, H, W, C]
+            decoded_shape = decoded.shape
+            return decoded.reshape(batch_size, seq_len, *decoded_shape[1:])
+        else:
+            # Standard latent processing
+            return self.__decode__(z, key=key, **kwargs)
+
+    def __call__(self, x: jnp.ndarray, key: Optional[jax.random.PRNGKey] = None, **kwargs):
+        """Encode and then decode the input (autoencoder).
+
+        Args:
+            x: Input tensor, either [B, H, W, C] for images or [B, T, H, W, C] for videos
+            key: Optional random key for stochastic encoding/decoding
+            **kwargs: Additional arguments for encoding and decoding
+
+        Returns:
+            Reconstructed output with the same dimensions as input
+        """
+        if key is not None:
+            encode_key, decode_key = jax.random.split(key)
+        else:
+            encode_key = decode_key = None
+
+        # Encode then decode
+        z = self.encode(x, key=encode_key, **kwargs)
+        return self.decode(z, key=decode_key, **kwargs)
+
+    @property
+    def spatial_scale(self) -> int:
+        """Get the spatial scale factor between input and latent spaces."""
+        return getattr(self, "_spatial_scale", None)
+
+    @property
+    def name(self) -> str:
+        """Get the name of the autoencoder model."""
         raise NotImplementedError
 
-    def __call__(self, x: jnp.ndarray):
-        latents = self.encode(x)
-        reconstructions = self.decode(latents)
-        return reconstructions
+    @abstractmethod
+    def serialize(self) -> Dict[str, Any]:
+        """Serialize the model parameters and configuration."""
+        raise NotImplementedError
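
The rewritten base class is a template method: `encode`/`decode` fold a leading temporal axis into the batch axis, call the frame-level `__encode__`/`__decode__`, and restore `[B, T, ...]` afterwards, so subclasses never see 5D tensors. A toy subclass, purely illustrative (the 8x pooling stands in for a real encoder):

```python
import jax.numpy as jnp
from flaxdiff.models.autoencoder import AutoEncoder

class ToyAutoEncoder(AutoEncoder):
    """Illustrative only: 8x average-pool as 'encode', nearest upsample as 'decode'."""

    def __encode__(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
        B, H, W, C = x.shape
        return x.reshape(B, H // 8, 8, W // 8, 8, C).mean(axis=(2, 4))

    def __decode__(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
        return jnp.repeat(jnp.repeat(z, 8, axis=1), 8, axis=2)

    def serialize(self):
        return {}

ae = ToyAutoEncoder()
video = jnp.ones((2, 16, 64, 64, 3))   # [B, T, H, W, C]
latents = ae.encode(video)              # [2, 16, 8, 8, 3]: T survives the round trip
recon = ae.decode(latents)              # [2, 16, 64, 64, 3]
```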
flaxdiff/models/autoencoder/diffusers.py
@@ -22,7 +22,9 @@ class StableDiffusionVAE(AutoEncoder):
            dtype=dtype,
        )
 
-        # vae = pipeline.vae
+        self.modelname = modelname
+        self.revision = revision
+        self.dtype = dtype
 
        enc = FlaxEncoder(
            in_channels=vae.config.in_channels,
@@ -63,29 +65,90 @@
            dtype=vae.dtype,
        )
 
-        self.enc = enc
-        self.dec = dec
-        self.post_quant_conv = post_quant_conv
-        self.quant_conv = quant_conv
-        self.params = params
-        self.scaling_factor = vae.scaling_factor
+        scaling_factor = vae.scaling_factor
+        print(f"Scaling factor: {scaling_factor}")
 
-    def encode(self, images, rngkey: jax.random.PRNGKey = None):
-        latents = self.enc.apply({"params": self.params["vae"]['encoder']}, images, deterministic=True)
-        latents = self.quant_conv.apply({"params": self.params["vae"]['quant_conv']}, latents)
-        if rngkey is not None:
-            mean, log_std = jnp.split(latents, 2, axis=-1)
-            log_std = jnp.clip(log_std, -30, 20)
-            std = jnp.exp(0.5 * log_std)
-            latents = mean + std * jax.random.normal(rngkey, mean.shape, dtype=mean.dtype)
-            # print("Sampled")
-        else:
-            # return the mean
-            latents, _ = jnp.split(latents, 2, axis=-1)
-        latents *= self.scaling_factor
-        return latents
+        def encode_single_frame(images, rngkey: jax.random.PRNGKey = None):
+            latents = enc.apply({"params": params['encoder']}, images, deterministic=True)
+            latents = quant_conv.apply({"params": params['quant_conv']}, latents)
+            if rngkey is not None:
+                mean, log_std = jnp.split(latents, 2, axis=-1)
+                log_std = jnp.clip(log_std, -30, 20)
+                std = jnp.exp(0.5 * log_std)
+                latents = mean + std * jax.random.normal(rngkey, mean.shape, dtype=mean.dtype)
+            else:
+                latents, _ = jnp.split(latents, 2, axis=-1)
+            latents *= scaling_factor
+            return latents
+
+        def decode_single_frame(latents):
+            latents = (1.0 / scaling_factor) * latents
+            latents = post_quant_conv.apply({"params": params['post_quant_conv']}, latents)
+            return dec.apply({"params": params['decoder']}, latents)
+
+        self.encode_single_frame = jax.jit(encode_single_frame)
+        self.decode_single_frame = jax.jit(decode_single_frame)
+
+        # Calculate downscale factor by passing a dummy input through the encoder
+        print("Calculating downscale factor...")
+        dummy_input = jnp.ones((1, 128, 128, 3), dtype=dtype)
+        dummy_latents = self.encode_single_frame(dummy_input)
+        _, h, w, c = dummy_latents.shape
+        _, H, W, C = dummy_input.shape
+        self.__downscale_factor__ = H // h
+        self.__latent_channels__ = c
+        print(f"Downscale factor: {self.__downscale_factor__}")
+        print(f"Latent channels: {self.__latent_channels__}")
+
+    def __encode__(self, images, key: jax.random.PRNGKey = None, **kwargs):
+        """Encode a batch of images to latent representations.
+
+        Implements the abstract method from the parent class.
+
+        Args:
+            images: Image tensor of shape [B, H, W, C]
+            key: Optional random key for stochastic encoding
+            **kwargs: Additional arguments (unused)
+
+        Returns:
+            Latent representations of shape [B, h, w, c]
+        """
+        return self.encode_single_frame(images, key)
+
+    def __decode__(self, latents, **kwargs):
+        """Decode latent representations to images.
+
+        Implements the abstract method from the parent class.
+
+        Args:
+            latents: Latent tensor of shape [B, h, w, c]
+            **kwargs: Additional arguments (unused)
+
+        Returns:
+            Decoded images of shape [B, H, W, C]
+        """
+        return self.decode_single_frame(latents)
+
+    @property
+    def downscale_factor(self) -> int:
+        """Returns the downscale factor for the encoder."""
+        return self.__downscale_factor__
+
+    @property
+    def latent_channels(self) -> int:
+        """Returns the number of channels in the latent space."""
+        return self.__latent_channels__
+
+    @property
+    def name(self) -> str:
+        """Get the name of the autoencoder model."""
+        return "stable_diffusion"
 
-    def decode(self, latents):
-        latents = (1.0 / self.scaling_factor) * latents
-        latents = self.post_quant_conv.apply({"params": self.params["vae"]['post_quant_conv']}, latents)
-        return self.dec.apply({"params": self.params["vae"]['decoder']}, latents)
+    def serialize(self):
+        """Serialize the model to a dictionary format."""
+        return {
+            "modelname": self.modelname,
+            "revision": self.revision,
+            "dtype": str(self.dtype),
+        }
+
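
After this change the VAE carries its own jitted closures and geometry, so callers no longer thread `params` through `apply`. A hedged usage sketch; the constructor arguments other than `modelname`, `revision`, and `dtype` (and their defaults) are not shown in this hunk and are assumed here:

```python
import jax
import jax.numpy as jnp
from flaxdiff.models.autoencoder.diffusers import StableDiffusionVAE

vae = StableDiffusionVAE()  # assumed: defaults load a pretrained SD VAE

# Geometry is measured at construction time via the dummy forward pass above.
print(vae.downscale_factor, vae.latent_channels)  # 8, 4 for the SD VAE

frames = jnp.ones((1, 8, 256, 256, 3))  # [B, T, H, W, C] video batch
key = jax.random.PRNGKey(0)
latents = vae.encode(frames, key=key)   # sampled posterior, [1, 8, 32, 32, 4]
recon = vae.decode(latents)             # [1, 8, 256, 256, 3]
```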
flaxdiff/models/autoencoder/simple_autoenc.py
@@ -6,21 +6,53 @@ from flax.typing import Dtype, PrecisionLike
 from .autoencoder import AutoEncoder
 
 class SimpleAutoEncoder(AutoEncoder):
+    """A simple autoencoder implementation using the abstract method pattern.
+
+    This implementation allows for handling both image and video data through
+    the parent class's handling of video reshaping.
+    """
     latent_channels: int
     feature_depths: List[int]=[64, 128, 256, 512]
-    attention_configs:list=[{"heads":8}, {"heads":8}, {"heads":8}, {"heads":8}],
+    attention_configs:list=[{"heads":8}, {"heads":8}, {"heads":8}, {"heads":8}]
     num_res_blocks: int=2
-    num_middle_res_blocks:int=1,
+    num_middle_res_blocks:int=1
     activation:Callable = jax.nn.swish
     norm_groups:int=8
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
 
-    # def encode(self, x: jnp.ndarray):
+    def __encode__(self, x: jnp.ndarray, **kwargs):
+        """Encode a batch of images to latent representations.
+
+        Implements the abstract method from the parent class.
 
+        Args:
+            x: Image tensor of shape [B, H, W, C]
+            **kwargs: Additional arguments
+
+        Returns:
+            Latent representations of shape [B, h, w, c]
+        """
+        # TODO: Implement the actual encoding logic for single frames
+        # This is just a placeholder implementation
+        B, H, W, C = x.shape
+        h, w = H // 8, W // 8  # Example downsampling factor
+        return jnp.zeros((B, h, w, self.latent_channels))
 
-    @nn.compact
-    def __call__(self, x: jnp.ndarray):
-        latents = self.encode(x)
-        reconstructions = self.decode(latents)
-        return reconstructions
+    def __decode__(self, z: jnp.ndarray, **kwargs):
+        """Decode latent representations to images.
+
+        Implements the abstract method from the parent class.
+
+        Args:
+            z: Latent tensor of shape [B, h, w, c]
+            **kwargs: Additional arguments
+
+        Returns:
+            Decoded images of shape [B, H, W, C]
+        """
+        # TODO: Implement the actual decoding logic for single frames
+        # This is just a placeholder implementation
+        B, h, w, c = z.shape
+        H, W = h * 8, w * 8  # Example upsampling factor
+        return jnp.zeros((B, H, W, 3))
flaxdiff/models/simple_unet.py
@@ -10,11 +10,11 @@ from functools import partial
 
 class Unet(nn.Module):
     output_channels:int=3
-    emb_features:int=64*4,
-    feature_depths:list=[64, 128, 256, 512],
-    attention_configs:list=[{"heads":8}, {"heads":8}, {"heads":8}, {"heads":8}],
-    num_res_blocks:int=2,
-    num_middle_res_blocks:int=1,
+    emb_features:int=64*4
+    feature_depths:list=(64, 128, 256, 512)
+    attention_configs:list=({"heads":8}, {"heads":8}, {"heads":8}, {"heads":8})
+    num_res_blocks:int=2
+    num_middle_res_blocks:int=1
     activation:Callable = jax.nn.swish
     norm_groups:int=8
     dtype: Optional[Dtype] = None
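
The removed trailing commas were genuine bugs, not style: in a class body, `emb_features:int=64*4,` makes the default the one-element tuple `(256,)`. The list-to-tuple change matters too, since `nn.Module` is a dataclass and rejects mutable defaults. A quick demonstration of the comma pitfall:

```python
class Broken:
    emb_features: int = 64 * 4,  # trailing comma: the default is now a tuple

class Fixed:
    emb_features: int = 64 * 4

print(Broken.emb_features)  # (256,)
print(Fixed.emb_features)   # 256
```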
flaxdiff/models/simple_vit.py
@@ -51,7 +51,7 @@ class PositionalEncoding(nn.Module):
 class UViT(nn.Module):
     output_channels:int=3
     patch_size: int = 16
-    emb_features:int=768,
+    emb_features:int=768
     num_layers: int = 12
     num_heads: int = 12
     dropout_rate: float = 0.1