flaxdiff 0.1.38.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/__init__.py +5 -1
- flaxdiff/data/benchmark_decord.py +443 -0
- flaxdiff/data/dataloaders.py +608 -0
- flaxdiff/data/dataset_map.py +61 -6
- flaxdiff/data/online_loader.py +779 -150
- flaxdiff/data/sources/audio_utils.py +142 -0
- flaxdiff/data/sources/av_example.py +125 -0
- flaxdiff/data/sources/av_utils.py +590 -0
- flaxdiff/data/sources/base.py +129 -0
- flaxdiff/data/sources/images.py +309 -0
- flaxdiff/data/sources/utils.py +158 -0
- flaxdiff/data/sources/videos.py +250 -0
- flaxdiff/data/sources/voxceleb2.py +412 -0
- flaxdiff/inference/__init__.py +0 -0
- flaxdiff/inference/pipeline.py +260 -0
- flaxdiff/inference/utils.py +320 -0
- flaxdiff/inputs/__init__.py +173 -0
- flaxdiff/inputs/encoders.py +98 -0
- flaxdiff/models/__init__.py +2 -1
- flaxdiff/models/autoencoder/autoencoder.py +141 -9
- flaxdiff/models/autoencoder/diffusers.py +88 -25
- flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
- flaxdiff/models/simple_unet.py +5 -5
- flaxdiff/models/simple_vit.py +1 -1
- flaxdiff/models/unet_3d.py +446 -0
- flaxdiff/models/unet_3d_blocks.py +505 -0
- flaxdiff/samplers/common.py +358 -96
- flaxdiff/samplers/ddim.py +44 -5
- flaxdiff/schedulers/karras.py +20 -12
- flaxdiff/trainer/__init__.py +2 -1
- flaxdiff/trainer/autoencoder_trainer.py +1 -2
- flaxdiff/trainer/diffusion_trainer.py +33 -27
- flaxdiff/trainer/general_diffusion_trainer.py +583 -0
- flaxdiff/trainer/simple_trainer.py +48 -31
- flaxdiff/utils.py +128 -57
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
- flaxdiff-0.2.0.dist-info/RECORD +64 -0
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
- flaxdiff/data/datasets.py +0 -169
- flaxdiff/data/sources/gcs.py +0 -81
- flaxdiff/data/sources/tfds.py +0 -79
- flaxdiff/trainer/video_diffusion_trainer.py +0 -62
- flaxdiff-0.1.38.1.dist-info/RECORD +0 -50
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,173 @@
|
|
1
|
+
import jax
|
2
|
+
import jax.numpy as jnp
|
3
|
+
import flax.struct as struct
|
4
|
+
import flax.linen as nn
|
5
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple
|
6
|
+
from dataclasses import dataclass
|
7
|
+
from functools import partial
|
8
|
+
import numpy as np
|
9
|
+
from jax.sharding import Mesh, PartitionSpec as P
|
10
|
+
from abc import ABC, abstractmethod
|
11
|
+
|
12
|
+
from flaxdiff.models.autoencoder import AutoEncoder
|
13
|
+
from .encoders import *
|
14
|
+
|
15
|
+
@dataclass
class ConditionalInputConfig:
    """A single conditioning input (e.g. text) for the diffusion model.

    Attributes:
        encoder: Encoder turning raw batch data (or tokens) into embeddings.
        conditioning_data_key: Batch key holding the raw conditioning data;
            ``encoder.key`` is used when unset.
        pretokenized: If True, the batch already holds tokens and only
            ``encoder.encode_from_tokens`` is applied.
        unconditional_input: Raw input encoded once to build the
            unconditional embedding; ``""`` is used when None.
        model_key_override: Optional name for this input on the model side;
            ``encoder.key`` is used when unset.
    """
    encoder: ConditioningEncoder
    conditioning_data_key: Optional[str] = None  # Key in the batch for this conditioning input
    pretokenized: bool = False
    unconditional_input: Any = None
    model_key_override: Optional[str] = None  # Optional key override for the model

    # Not a dataclass field (no annotation); populated by __post_init__.
    __uncond_cache__ = None  # Cache for unconditional input

    def __post_init__(self):
        # Encode the "null" input once so later calls to get_unconditional()
        # do not re-run the encoder. The one-element list gives a batch axis of 1.
        raw_uncond = "" if self.unconditional_input is None else self.unconditional_input
        self.__uncond_cache__ = self.encoder([raw_uncond])

    def __call__(self, batch_data):
        """Encode this input's entry of ``batch_data``.

        Args:
            batch_data: Mapping from batch keys to data; the entry at
                ``conditioning_data_key`` (or ``encoder.key``) is encoded.

        Returns:
            The encoder output for that entry.
        """
        key = self.conditioning_data_key or self.encoder.key
        if self.pretokenized:
            return self.encoder.encode_from_tokens(batch_data[key])
        return self.encoder(batch_data[key])

    def get_unconditional(self):
        """Return the cached unconditional embedding (batch axis of size 1)."""
        return self.__uncond_cache__

    def serialize(self):
        """Serialize the configuration to a plain dictionary.

        ``pretokenized`` is included so a serialize/deserialize round trip
        keeps the input's tokenization mode.
        """
        return {
            "encoder": self.encoder.serialize(),
            "encoder_key": self.encoder.key,
            "conditioning_data_key": self.conditioning_data_key,
            "pretokenized": self.pretokenized,
            "unconditional_input": self.unconditional_input,
            "model_key_override": self.model_key_override,
        }

    @staticmethod
    def deserialize(serialized_config):
        """Rebuild a configuration produced by ``serialize``.

        Configs written without a ``pretokenized`` entry default to False.

        Raises:
            ValueError: If ``encoder_key`` is not in CONDITIONAL_ENCODERS_REGISTRY.
        """
        encoder_key = serialized_config["encoder_key"]
        encoder_class = CONDITIONAL_ENCODERS_REGISTRY.get(encoder_key)
        if encoder_class is None:
            raise ValueError(f"Unknown encoder type: {encoder_key}")

        encoder = encoder_class.deserialize(serialized_config["encoder"])
        return ConditionalInputConfig(
            encoder=encoder,
            conditioning_data_key=serialized_config.get("conditioning_data_key"),
            pretokenized=serialized_config.get("pretokenized", False),
            unconditional_input=serialized_config.get("unconditional_input"),
            model_key_override=serialized_config.get("model_key_override"),
        )
|
75
|
+
|
76
|
+
@dataclass
class DiffusionInputConfig:
    """Describes the sample tensor and the conditioning inputs of a diffusion model."""
    sample_data_key: str  # Key in the batch for the sample data
    sample_data_shape: Tuple[int, ...]  # (H, W, C) for images or (T, H, W, C) for videos
    conditions: List[ConditionalInputConfig]

    def get_input_shapes(
        self,
        autoencoder: AutoEncoder = None,
        sample_model_key: str = 'x',
        time_embeddings_model_key: str = 'temb',
    ) -> Dict[str, Tuple[int, ...]]:
        """Compute per-example (batch-less) shapes of every model input.

        Args:
            autoencoder: If given, the sample shape is converted to latent
                space using its ``downscale_factor`` and ``latent_channels``.
            sample_model_key: Name of the sample input.
            time_embeddings_model_key: Name of the (scalar) time input.

        Returns:
            Mapping from input name to shape. For 4D (video) sample shapes the
            leading frame count is kept: ``(T, H, W, C)``.

        Raises:
            ValueError: If ``sample_data_shape`` is neither 3D nor 4D.
        """
        ndim = len(self.sample_data_shape)
        if ndim == 3:
            num_frames = None
            H, W, C = self.sample_data_shape
        elif ndim == 4:
            num_frames, H, W, C = self.sample_data_shape
        else:
            raise ValueError(f"Unsupported shape for sample data {self.sample_data_shape}")

        if autoencoder is not None:
            # The autoencoder works frame by frame, so only the spatial dims and
            # channels change; the frame count is preserved.
            downscale_factor = autoencoder.downscale_factor
            H = H // downscale_factor
            W = W // downscale_factor
            C = autoencoder.latent_channels

        sample_shape = (H, W, C) if num_frames is None else (num_frames, H, W, C)
        input_shapes = {
            sample_model_key: sample_shape,
            time_embeddings_model_key: (),
        }
        for cond in self.conditions:
            key = cond.model_key_override or cond.encoder.key
            # Index 0 drops the size-1 batch axis of the cached unconditional embedding.
            input_shapes[key] = cond.get_unconditional()[0].shape

        print(f"Calculated input shapes: {input_shapes}")
        return input_shapes

    def get_unconditionals(self):
        """Return the cached unconditional embedding of each condition, in order."""
        return [cond.get_unconditional() for cond in self.conditions]

    def process_conditioning(self, batch_data, uncond_mask: Optional[jnp.ndarray] = None):
        """Encode every condition, optionally swapping in unconditional embeddings.

        Args:
            batch_data: Batch mapping passed to each ``ConditionalInputConfig``.
            uncond_mask: Optional per-example boolean mask; where True, that
                example's embedding is replaced by the unconditional one.

        Returns:
            List of embeddings, one per condition.

        Raises:
            ValueError: If ``uncond_mask`` length differs from the batch size.
        """
        results = []
        for cond in self.conditions:
            cond_embeddings = cond(batch_data)
            if uncond_mask is not None:
                batch_size = len(cond_embeddings)
                if len(uncond_mask) != batch_size:
                    raise ValueError("Unconditional mask length must match the batch size.")
                # (B,) -> (B, 1, ..., 1) so the mask broadcasts over all
                # non-batch axes, e.g. (B, 1, 1) for (B, T, D) embeddings.
                mask = jnp.reshape(uncond_mask, (batch_size,) + (1,) * (cond_embeddings.ndim - 1))
                # The unconditional embedding has a batch axis of 1, so
                # jnp.where broadcasts it across the batch without a repeat.
                cond_embeddings = jnp.where(mask, cond.get_unconditional(), cond_embeddings)
            results.append(cond_embeddings)
        return results

    def serialize(self):
        """Serialize the configuration to a plain dictionary."""
        return {
            "sample_data_key": self.sample_data_key,
            "sample_data_shape": self.sample_data_shape,
            "conditions": [cond.serialize() for cond in self.conditions],
        }

    @staticmethod
    def deserialize(serialized_config):
        """Rebuild a configuration produced by ``serialize``."""
        return DiffusionInputConfig(
            sample_data_key=serialized_config["sample_data_key"],
            sample_data_shape=tuple(serialized_config["sample_data_shape"]),
            conditions=[
                ConditionalInputConfig.deserialize(cond)
                for cond in serialized_config["conditions"]
            ],
        )
|
@@ -0,0 +1,98 @@
|
|
1
|
+
import jax.numpy as jnp
|
2
|
+
import flax.linen as nn
|
3
|
+
from typing import Callable
|
4
|
+
from dataclasses import dataclass
|
5
|
+
from abc import ABC, abstractmethod
|
6
|
+
|
7
|
+
@dataclass
class ConditioningEncoder(ABC):
    """Base class pairing a tokenizer with a model that embeds conditioning data.

    The default ``tokenize`` calls the tokenizer with HuggingFace-style
    keyword arguments. The default ``encode_from_tokens`` returns the model
    output's ``last_hidden_state``. Subclasses must implement
    ``serialize``/``deserialize``.
    """
    model: nn.Module
    tokenizer: Callable

    @property
    def key(self):
        """Identifier derived from the encoder's class name.

        A trailing ``"Encoder"`` is removed and the result lowercased, e.g.
        a subclass named ``AudioEncoder`` yields ``"audio"``.
        """
        # Use the encoder class name: tokenizer instances generally have no
        # ``__name__`` attribute.
        name = type(self).__name__
        suffix = "Encoder"
        if name.endswith(suffix):
            name = name[:-len(suffix)]
        return name.lower()

    def __call__(self, data):
        """Tokenize ``data`` and return its embeddings."""
        tokens = self.tokenize(data)
        return self.encode_from_tokens(tokens)

    def encode_from_tokens(self, tokens):
        """Run the model on pre-tokenized input and return ``last_hidden_state``."""
        outputs = self.model(input_ids=tokens['input_ids'],
                             attention_mask=tokens['attention_mask'])
        return outputs.last_hidden_state

    def tokenize(self, data):
        """Tokenize ``data``, padded/truncated to ``tokenizer.model_max_length``, as NumPy arrays."""
        return self.tokenizer(data, padding="max_length",
                              max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np")

    @abstractmethod
    def serialize(self):
        """Serialize the encoder configuration."""
        pass

    @staticmethod
    @abstractmethod
    def deserialize(serialized_config):
        """Deserialize the encoder configuration."""
        pass
|
46
|
+
|
47
|
+
@dataclass
class TextEncoder(ConditioningEncoder):
    """Base class for encoders whose conditioning data is text."""

    @property
    def key(self):
        """Fixed identifier shared by all text encoders."""
        text_key = "text"
        return text_key
|
53
|
+
|
54
|
+
@dataclass
class CLIPTextEncoder(TextEncoder):
    """CLIP text encoder loaded through ``transformers``.

    Attributes:
        modelname: Pretrained checkpoint name passed to ``from_pretrained``.
        backend: ``"jax"`` loads ``FlaxCLIPTextModel`` in bfloat16; any other
            value loads ``CLIPTextModel``.
    """
    modelname: str
    backend: str

    @staticmethod
    def from_modelname(modelname: str = "openai/clip-vit-large-patch14", backend: str = "jax"):
        """Load the tokenizer and text model for ``modelname``.

        Args:
            modelname: Checkpoint to load.
            backend: ``"jax"`` for the Flax model, anything else for the
                ``CLIPTextModel`` class.

        Returns:
            A ``CLIPTextEncoder`` wrapping the loaded model and tokenizer.
        """
        from transformers import (
            CLIPTextModel,
            FlaxCLIPTextModel,
            AutoTokenizer,
        )
        # Load exactly the requested checkpoint; deserialize() relies on this
        # to restore the model that was serialized.
        if backend == "jax":
            model = FlaxCLIPTextModel.from_pretrained(
                modelname, dtype=jnp.bfloat16)
        else:
            # NOTE(review): tokenize() returns NumPy tensors, which a PyTorch
            # CLIPTextModel may not accept directly - confirm for this backend.
            model = CLIPTextModel.from_pretrained(modelname)
        # dtype is a model setting; the tokenizer takes none.
        tokenizer = AutoTokenizer.from_pretrained(modelname)
        return CLIPTextEncoder(
            model=model,
            tokenizer=tokenizer,
            modelname=modelname,
            backend=backend,
        )

    def serialize(self):
        """Serialize the encoder configuration."""
        return {
            "modelname": self.modelname,
            "backend": self.backend,
        }

    @staticmethod
    def deserialize(serialized_config):
        """Reload the encoder described by a ``serialize`` output."""
        return CLIPTextEncoder.from_modelname(
            modelname=serialized_config["modelname"],
            backend=serialized_config["backend"],
        )
|
95
|
+
|
96
|
+
# Maps a serialized ``encoder_key`` to the encoder class whose
# ``deserialize`` can rebuild it.
CONDITIONAL_ENCODERS_REGISTRY = dict(
    text=CLIPTextEncoder,
)
|
flaxdiff/models/__init__.py
CHANGED
@@ -1 +1,2 @@
|
|
1
|
-
from .simple_unet import *
|
1
|
+
from .simple_unet import *
|
2
|
+
# from .video_unet import FlaxUNet3DConditionModel, BCHWModelWrapper, FlaxTemporalConvLayer
|
@@ -1,19 +1,151 @@
|
|
1
1
|
import jax
|
2
2
|
import jax.numpy as jnp
|
3
3
|
from flax import linen as nn
|
4
|
-
from typing import Dict, Callable, Sequence, Any, Union
|
4
|
+
from typing import Dict, Callable, Sequence, Any, Union, Optional
|
5
5
|
import einops
|
6
6
|
from ..common import kernel_init, ConvLayer, Upsample, Downsample, PixelShuffle
|
7
|
+
from dataclasses import dataclass
|
8
|
+
from abc import ABC, abstractmethod
|
7
9
|
|
8
|
-
|
9
|
-
class AutoEncoder():
|
10
|
-
|
10
|
+
@dataclass
class AutoEncoder(ABC):
    """Interface for autoencoders with built-in support for video batches.

    Concrete subclasses implement the per-frame hooks ``__encode__`` and
    ``__decode__`` (on ``[B, H, W, C]`` tensors) plus ``serialize``. The public
    ``encode``/``decode`` fold a 5D ``[B, T, H, W, C]`` tensor into
    ``[B*T, H, W, C]``, call the hook, and restore the ``(B, T)`` leading
    axes. Tensors of any other rank go to the hook unchanged.
    """

    @abstractmethod
    def __encode__(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
        """Encode an image batch (implemented by subclasses).

        Args:
            x: Images of shape ``[B, H, W, C]``.
            **kwargs: Extra options; ``encode`` always forwards ``key``.

        Returns:
            Latent representation of ``x``.
        """
        raise NotImplementedError

    @abstractmethod
    def __decode__(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
        """Decode a latent batch (implemented by subclasses).

        Args:
            z: Latents of shape ``[B, h, w, c]``.
            **kwargs: Extra options; ``decode`` always forwards ``key``.

        Returns:
            Decoded images.
        """
        raise NotImplementedError

    def _apply_framewise(self, hook: Callable[..., jnp.ndarray], data: jnp.ndarray, /, **kwargs) -> jnp.ndarray:
        """Run ``hook`` on ``data``, treating a 5D tensor as a batch of frame sequences."""
        if len(data.shape) != 5:
            return hook(data, **kwargs)
        batch, frames = data.shape[:2]
        # Merge batch and time so the hook sees an ordinary 4D batch.
        flat_out = hook(data.reshape(-1, *data.shape[2:]), **kwargs)
        # Split the merged axis back into (batch, time).
        return flat_out.reshape(batch, frames, *flat_out.shape[1:])

    def encode(self, x: jnp.ndarray, key: Optional[jax.random.PRNGKey] = None, **kwargs) -> jnp.ndarray:
        """Encode images ``[B, H, W, C]`` or videos ``[B, T, H, W, C]``.

        Args:
            x: Input images or videos.
            key: Optional PRNG key forwarded to ``__encode__``.
            **kwargs: Forwarded to ``__encode__``.

        Returns:
            Latents that keep the input's leading batch (and time) axes.
        """
        return self._apply_framewise(self.__encode__, x, key=key, **kwargs)

    def decode(self, z: jnp.ndarray, key: Optional[jax.random.PRNGKey] = None, **kwargs) -> jnp.ndarray:
        """Decode latents ``[B, h, w, c]`` or video latents ``[B, T, h, w, c]``.

        Args:
            z: Input latents.
            key: Optional PRNG key forwarded to ``__decode__``.
            **kwargs: Forwarded to ``__decode__``.

        Returns:
            Decoded output that keeps the input's leading batch (and time) axes.
        """
        return self._apply_framewise(self.__decode__, z, key=key, **kwargs)

    def __call__(self, x: jnp.ndarray, key: Optional[jax.random.PRNGKey] = None, **kwargs):
        """Reconstruct ``x`` by encoding and then decoding it.

        When ``key`` is given it is split into independent encode/decode keys;
        ``kwargs`` are passed to both stages.
        """
        encode_key, decode_key = (None, None) if key is None else jax.random.split(key)
        latents = self.encode(x, key=encode_key, **kwargs)
        return self.decode(latents, key=decode_key, **kwargs)

    @property
    def spatial_scale(self) -> int:
        """Spatial scale between input and latent space, or None if not set."""
        return getattr(self, "_spatial_scale", None)

    @property
    def name(self) -> str:
        """Name of the autoencoder model (must be provided by subclasses)."""
        raise NotImplementedError

    @abstractmethod
    def serialize(self) -> Dict[str, Any]:
        """Serialize the model parameters and configuration."""
        raise NotImplementedError
|
@@ -22,7 +22,9 @@ class StableDiffusionVAE(AutoEncoder):
|
|
22
22
|
dtype=dtype,
|
23
23
|
)
|
24
24
|
|
25
|
-
|
25
|
+
self.modelname = modelname
|
26
|
+
self.revision = revision
|
27
|
+
self.dtype = dtype
|
26
28
|
|
27
29
|
enc = FlaxEncoder(
|
28
30
|
in_channels=vae.config.in_channels,
|
@@ -63,29 +65,90 @@ class StableDiffusionVAE(AutoEncoder):
|
|
63
65
|
dtype=vae.dtype,
|
64
66
|
)
|
65
67
|
|
66
|
-
|
67
|
-
|
68
|
-
self.post_quant_conv = post_quant_conv
|
69
|
-
self.quant_conv = quant_conv
|
70
|
-
self.params = params
|
71
|
-
self.scaling_factor = vae.scaling_factor
|
68
|
+
scaling_factor = vae.scaling_factor
|
69
|
+
print(f"Scaling factor: {scaling_factor}")
|
72
70
|
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
71
|
+
def encode_single_frame(images, rngkey: jax.random.PRNGKey = None):
|
72
|
+
latents = enc.apply({"params": params['encoder']}, images, deterministic=True)
|
73
|
+
latents = quant_conv.apply({"params": params['quant_conv']}, latents)
|
74
|
+
if rngkey is not None:
|
75
|
+
mean, log_std = jnp.split(latents, 2, axis=-1)
|
76
|
+
log_std = jnp.clip(log_std, -30, 20)
|
77
|
+
std = jnp.exp(0.5 * log_std)
|
78
|
+
latents = mean + std * jax.random.normal(rngkey, mean.shape, dtype=mean.dtype)
|
79
|
+
else:
|
80
|
+
latents, _ = jnp.split(latents, 2, axis=-1)
|
81
|
+
latents *= scaling_factor
|
82
|
+
return latents
|
83
|
+
|
84
|
+
def decode_single_frame(latents):
|
85
|
+
latents = (1.0 / scaling_factor) * latents
|
86
|
+
latents = post_quant_conv.apply({"params": params['post_quant_conv']}, latents)
|
87
|
+
return dec.apply({"params": params['decoder']}, latents)
|
88
|
+
|
89
|
+
self.encode_single_frame = jax.jit(encode_single_frame)
|
90
|
+
self.decode_single_frame = jax.jit(decode_single_frame)
|
91
|
+
|
92
|
+
# Calculate downscale factor by passing a dummy input through the encoder
|
93
|
+
print("Calculating downscale factor...")
|
94
|
+
dummy_input = jnp.ones((1, 128, 128, 3), dtype=dtype)
|
95
|
+
dummy_latents = self.encode_single_frame(dummy_input)
|
96
|
+
_, h, w, c = dummy_latents.shape
|
97
|
+
_, H, W, C = dummy_input.shape
|
98
|
+
self.__downscale_factor__ = H // h
|
99
|
+
self.__latent_channels__ = c
|
100
|
+
print(f"Downscale factor: {self.__downscale_factor__}")
|
101
|
+
print(f"Latent channels: {self.__latent_channels__}")
|
102
|
+
|
103
|
+
def __encode__(self, images, key: jax.random.PRNGKey = None, **kwargs):
    """Map an image batch ``[B, H, W, C]`` to scaled VAE latents ``[B, h, w, c]``.

    Delegates to the jitted ``encode_single_frame`` built in ``__init__``.
    With a ``key`` the latent is sampled from the encoder's Gaussian;
    without one its mean is used. Extra keyword arguments are ignored.
    """
    frame_encoder = self.encode_single_frame
    return frame_encoder(images, key)
|
117
|
+
|
118
|
+
def __decode__(self, latents, **kwargs):
    """Map scaled latents ``[B, h, w, c]`` back to images ``[B, H, W, C]``.

    Delegates to the jitted ``decode_single_frame`` built in ``__init__``.
    Keyword arguments (such as a forwarded ``key``) are ignored.
    """
    frame_decoder = self.decode_single_frame
    return frame_decoder(latents)
|
131
|
+
|
132
|
+
@property
def downscale_factor(self) -> int:
    """Input-to-latent spatial ratio, measured on a dummy encode in ``__init__``."""
    factor = self.__downscale_factor__
    return factor
|
136
|
+
|
137
|
+
@property
def latent_channels(self) -> int:
    """Channel count of the latent tensor, measured on a dummy encode in ``__init__``."""
    channels = self.__latent_channels__
    return channels
|
141
|
+
|
142
|
+
@property
def name(self) -> str:
    """Identifier of this autoencoder family."""
    model_name = "stable_diffusion"
    return model_name
|
87
146
|
|
88
|
-
def
|
89
|
-
|
90
|
-
|
91
|
-
|
147
|
+
def serialize(self):
    """Return the settings this VAE was built with as a plain dictionary.

    ``dtype`` is stored as its ``str()`` representation.
    """
    config = dict(modelname=self.modelname, revision=self.revision)
    config["dtype"] = str(self.dtype)
    return config
|
154
|
+
|
@@ -6,21 +6,53 @@ from flax.typing import Dtype, PrecisionLike
|
|
6
6
|
from .autoencoder import AutoEncoder
|
7
7
|
|
8
8
|
class SimpleAutoEncoder(AutoEncoder):
    """A simple autoencoder implementation using the abstract method pattern.

    Video batches are handled by the parent class's reshaping. The encode and
    decode hooks are still placeholders that return zero tensors with a fixed
    spatial factor of ``_PLACEHOLDER_SCALE``.
    """
    latent_channels: int
    feature_depths: List[int] = [64, 128, 256, 512]
    attention_configs: list = [{"heads": 8}, {"heads": 8}, {"heads": 8}, {"heads": 8}]
    num_res_blocks: int = 2
    num_middle_res_blocks: int = 1
    activation: Callable = jax.nn.swish
    norm_groups: int = 8
    dtype: Optional[Dtype] = None
    precision: PrecisionLike = None

    # Spatial down/up-sampling factor shared by the placeholder hooks below and
    # reported by ``downscale_factor``.
    _PLACEHOLDER_SCALE = 8

    def __encode__(self, x: jnp.ndarray, **kwargs):
        """Encode a batch of images to latent representations.

        Args:
            x: Image tensor of shape [B, H, W, C].
            **kwargs: Additional arguments (unused).

        Returns:
            Latent representations of shape [B, H // 8, W // 8, latent_channels].
        """
        # TODO: Implement the actual encoding logic for single frames
        # This is just a placeholder implementation
        B, H, W, C = x.shape
        scale = self._PLACEHOLDER_SCALE
        return jnp.zeros((B, H // scale, W // scale, self.latent_channels))

    def __decode__(self, z: jnp.ndarray, **kwargs):
        """Decode latent representations to images.

        Args:
            z: Latent tensor of shape [B, h, w, c].
            **kwargs: Additional arguments (unused).

        Returns:
            Decoded images of shape [B, h * 8, w * 8, 3].
        """
        # TODO: Implement the actual decoding logic for single frames
        # This is just a placeholder implementation
        B, h, w, c = z.shape
        scale = self._PLACEHOLDER_SCALE
        return jnp.zeros((B, h * scale, w * scale, 3))

    @property
    def downscale_factor(self) -> int:
        """Spatial reduction factor between images and latents."""
        return self._PLACEHOLDER_SCALE

    def serialize(self):
        """Serialize the configuration of this autoencoder to a dictionary.

        Implements the abstract ``AutoEncoder.serialize``; without it the class
        could not be instantiated. ``latent_channels`` is None when it was never
        assigned on the instance.
        """
        return {
            "latent_channels": getattr(self, "latent_channels", None),
            "feature_depths": list(self.feature_depths),
            "attention_configs": list(self.attention_configs),
            "num_res_blocks": self.num_res_blocks,
            "num_middle_res_blocks": self.num_middle_res_blocks,
            "activation": getattr(self.activation, "__name__", str(self.activation)),
            "norm_groups": self.norm_groups,
            "dtype": str(self.dtype),
            "precision": str(self.precision),
        }
|
flaxdiff/models/simple_unet.py
CHANGED
@@ -10,11 +10,11 @@ from functools import partial
|
|
10
10
|
|
11
11
|
class Unet(nn.Module):
|
12
12
|
output_channels:int=3
|
13
|
-
emb_features:int=64*4
|
14
|
-
feature_depths:list=
|
15
|
-
attention_configs:list=
|
16
|
-
num_res_blocks:int=2
|
17
|
-
num_middle_res_blocks:int=1
|
13
|
+
emb_features:int=64*4
|
14
|
+
feature_depths:list=(64, 128, 256, 512)
|
15
|
+
attention_configs:list=({"heads":8}, {"heads":8}, {"heads":8}, {"heads":8})
|
16
|
+
num_res_blocks:int=2
|
17
|
+
num_middle_res_blocks:int=1
|
18
18
|
activation:Callable = jax.nn.swish
|
19
19
|
norm_groups:int=8
|
20
20
|
dtype: Optional[Dtype] = None
|
flaxdiff/models/simple_vit.py
CHANGED